{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nearest Neighbor in TensorFlow\n",
    "\n",
    "Credits: Forked from [TensorFlow-Examples](https://github.com/aymericdamien/TensorFlow-Examples) by Aymeric Damien\n",
    "\n",
    "Classify MNIST digits with a 1-nearest-neighbor (1-NN) classifier: each test image is assigned the label of the training image closest to it in L1 distance (the sum of absolute pixel differences).\n",
    "\n",
    "## Setup\n",
    "\n",
    "Refer to the [setup instructions](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/Setup_TensorFlow.md)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import input_data  # local module providing read_data_sets() for MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Import MNIST data (labels are one-hot encoded)\n",
    "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Use a subset of MNIST: every test image is compared against every training image,\n",
    "# so the cost grows with (number of training images) x (number of test images)\n",
    "Xtr, Ytr = mnist.train.next_batch(5000) # 5000 for training (nn candidates)\n",
    "Xte, Yte = mnist.test.next_batch(200) # 200 for testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Reshape images to 1D: one flattened 28*28 = 784 pixel vector per row\n",
    "Xtr = np.reshape(Xtr, newshape=(-1, 28*28))\n",
    "Xte = np.reshape(Xte, newshape=(-1, 28*28))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the nearest-neighbor graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# tf Graph Input\n",
    "# xtr: all training images, one flattened image per row\n",
    "# xte: a single flattened test image\n",
    "xtr = tf.placeholder(\"float\", [None, 784])\n",
    "xte = tf.placeholder(\"float\", [784])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Nearest Neighbor calculation using L1 Distance\n",
    "# xte broadcasts against every row of xtr; summing |difference| over axis 1 (pixels)\n",
    "# gives one L1 distance per training image. The `-` operator is used because it\n",
    "# works across TensorFlow versions (tf.neg was removed in TensorFlow 1.0).\n",
    "distance = tf.reduce_sum(tf.abs(xtr - xte), 1)\n",
    "# Predict: Get min distance index (Nearest neighbor)\n",
    "pred = tf.argmin(distance, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Initializing the variables. This graph defines no tf.Variable (only placeholders),\n",
    "# so this op initializes nothing; it is kept to follow the usual session pattern.\n",
    "init = tf.initialize_all_variables()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate on the test subset\n",
    "\n",
    "Test images are fed one at a time. The prediction is the class (`argmax` of the one-hot label) of the nearest training image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test 0 Prediction: 7 True Class: 7\n",
      "Test 1 Prediction: 2 True Class: 2\n",
      "Test 2 Prediction: 1 True Class: 1\n",
      "Test 3 Prediction: 0 True Class: 0\n",
      "Test 4 Prediction: 4 True Class: 4\n",
      "Test 5 Prediction: 1 True Class: 1\n",
      "Test 6 Prediction: 4 True Class: 4\n",
      "Test 7 Prediction: 9 True Class: 9\n",
      "Test 8 Prediction: 8 True Class: 5\n",
      "Test 9 Prediction: 9 True Class: 9\n",
      "Test 10 Prediction: 0 True Class: 0\n",
      "Test 11 Prediction: 0 True Class: 6\n",
      "Test 12 Prediction: 9 True Class: 9\n",
      "Test 13 Prediction: 0 True Class: 0\n",
      "Test 14 Prediction: 1 True Class: 1\n",
      "Test 15 Prediction: 5 True Class: 5\n",
      "Test 16 Prediction: 4 True Class: 9\n",
      "Test 17 Prediction: 7 True Class: 7\n",
      "Test 18 Prediction: 3 True Class: 3\n",
      "Test 19 Prediction: 4 True Class: 4\n",
      "Test 20 Prediction: 9 True Class: 9\n",
      "Test 21 Prediction: 6 True Class: 6\n",
      "Test 22 Prediction: 6 True Class: 6\n",
      "Test 23 Prediction: 5 True Class: 5\n",
      "Test 24 Prediction: 4 True Class: 4\n",
      "Test 25 Prediction: 0 True Class: 0\n",
      "Test 26 Prediction: 7 True Class: 7\n",
      "Test 27 Prediction: 4 True Class: 4\n",
      "Test 28 Prediction: 0 True Class: 0\n",
      "Test 29 Prediction: 1 True Class: 1\n",
      "Test 30 Prediction: 3 True Class: 3\n",
      "Test 31 Prediction: 1 True Class: 1\n",
      "Test 32 Prediction: 3 True Class: 3\n",
      "Test 33 Prediction: 4 True Class: 4\n",
      "Test 34 Prediction: 7 True Class: 7\n",
      "Test 35 Prediction: 2 True Class: 2\n",
      "Test 36 Prediction: 7 True Class: 7\n",
      "Test 37 Prediction: 1 True Class: 1\n",
      "Test 38 Prediction: 2 True Class: 2\n",
      "Test 39 Prediction: 1 True Class: 1\n",
      "Test 40 Prediction: 1 True Class: 1\n",
      "Test 41 Prediction: 7 True Class: 7\n",
      "Test 42 Prediction: 4 True Class: 4\n",
      "Test 43 Prediction: 1 True Class: 2\n",
      "Test 44 Prediction: 3 True Class: 3\n",
      "Test 45 Prediction: 5 True Class: 5\n",
      "Test 46 Prediction: 1 True Class: 1\n",
      "Test 47 Prediction: 2 True Class: 2\n",
      "Test 48 Prediction: 4 True Class: 4\n",
      "Test 49 Prediction: 4 True Class: 4\n",
      "Test 50 Prediction: 6 True Class: 6\n",
      "Test 51 Prediction: 3 True Class: 3\n",
      "Test 52 Prediction: 5 True Class: 5\n",
      "Test 53 Prediction: 5 True Class: 5\n",
      "Test 54 Prediction: 6 True Class: 6\n",
      "Test 55 Prediction: 0 True Class: 0\n",
      "Test 56 Prediction: 4 True Class: 4\n",
      "Test 57 Prediction: 1 True Class: 1\n",
      "Test 58 Prediction: 9 True Class: 9\n",
      "Test 59 Prediction: 5 True Class: 5\n",
      "Test 60 Prediction: 7 True Class: 7\n",
      "Test 61 Prediction: 8 True Class: 8\n",
      "Test 62 Prediction: 9 True Class: 9\n",
      "Test 63 Prediction: 3 True Class: 3\n",
      "Test 64 Prediction: 7 True Class: 7\n",
      "Test 65 Prediction: 4 True Class: 4\n",
      "Test 66 Prediction: 6 True Class: 6\n",
      "Test 67 Prediction: 4 True Class: 4\n",
      "Test 68 Prediction: 3 True Class: 3\n",
      "Test 69 Prediction: 0 True Class: 0\n",
      "Test 70 Prediction: 7 True Class: 7\n",
      "Test 71 Prediction: 0 True Class: 0\n",
      "Test 72 Prediction: 2 True Class: 2\n",
      "Test 73 Prediction: 7 True Class: 9\n",
      "Test 74 Prediction: 1 True Class: 1\n",
      "Test 75 Prediction: 7 True Class: 7\n",
      "Test 76 Prediction: 3 True Class: 3\n",
      "Test 77 Prediction: 7 True Class: 2\n",
      "Test 78 Prediction: 9 True Class: 9\n",
      "Test 79 Prediction: 7 True Class: 7\n",
      "Test 80 Prediction: 7 True Class: 7\n",
      "Test 81 Prediction: 6 True Class: 6\n",
      "Test 82 Prediction: 2 True Class: 2\n",
      "Test 83 Prediction: 7 True Class: 7\n",
      "Test 84 Prediction: 8 True Class: 8\n",
      "Test 85 Prediction: 4 True Class: 4\n",
      "Test 86 Prediction: 7 True Class: 7\n",
      "Test 87 Prediction: 3 True Class: 3\n",
      "Test 88 Prediction: 6 True Class: 6\n",
      "Test 89 Prediction: 1 True Class: 1\n",
      "Test 90 Prediction: 3 True Class: 3\n",
      "Test 91 Prediction: 6 True Class: 6\n",
      "Test 92 Prediction: 9 True Class: 9\n",
      "Test 93 Prediction: 3 True Class: 3\n",
      "Test 94 Prediction: 1 True Class: 1\n",
      "Test 95 Prediction: 4 True Class: 4\n",
      "Test 96 Prediction: 1 True Class: 1\n",
      "Test 97 Prediction: 7 True Class: 7\n",
      "Test 98 Prediction: 6 True Class: 6\n",
      "Test 99 Prediction: 9 True Class: 9\n",
      "Test 100 Prediction: 6 True Class: 6\n",
      "Test 101 Prediction: 0 True Class: 0\n",
      "Test 102 Prediction: 5 True Class: 5\n",
      "Test 103 Prediction: 4 True Class: 4\n",
      "Test 104 Prediction: 9 True Class: 9\n",
      "Test 105 Prediction: 9 True Class: 9\n",
      "Test 106 Prediction: 2 True Class: 2\n",
      "Test 107 Prediction: 1 True Class: 1\n",
      "Test 108 Prediction: 9 True Class: 9\n",
      "Test 109 Prediction: 4 True Class: 4\n",
      "Test 110 Prediction: 8 True Class: 8\n",
      "Test 111 Prediction: 7 True Class: 7\n",
      "Test 112 Prediction: 3 True Class: 3\n",
      "Test 113 Prediction: 9 True Class: 9\n",
      "Test 114 Prediction: 7 True Class: 7\n",
      "Test 115 Prediction: 9 True Class: 4\n",
      "Test 116 Prediction: 9 True Class: 4\n",
      "Test 117 Prediction: 4 True Class: 4\n",
      "Test 118 Prediction: 9 True Class: 9\n",
      "Test 119 Prediction: 7 True Class: 2\n",
      "Test 120 Prediction: 5 True Class: 5\n",
      "Test 121 Prediction: 4 True Class: 4\n",
      "Test 122 Prediction: 7 True Class: 7\n",
      "Test 123 Prediction: 6 True Class: 6\n",
      "Test 124 Prediction: 7 True Class: 7\n",
      "Test 125 Prediction: 9 True Class: 9\n",
      "Test 126 Prediction: 0 True Class: 0\n",
      "Test 127 Prediction: 5 True Class: 5\n",
      "Test 128 Prediction: 8 True Class: 8\n",
      "Test 129 Prediction: 5 True Class: 5\n",
      "Test 130 Prediction: 6 True Class: 6\n",
      "Test 131 Prediction: 6 True Class: 6\n",
      "Test 132 Prediction: 5 True Class: 5\n",
      "Test 133 Prediction: 7 True Class: 7\n",
      "Test 134 Prediction: 8 True Class: 8\n",
      "Test 135 Prediction: 1 True Class: 1\n",
      "Test 136 Prediction: 0 True Class: 0\n",
      "Test 137 Prediction: 1 True Class: 1\n",
      "Test 138 Prediction: 6 True Class: 6\n",
      "Test 139 Prediction: 4 True Class: 4\n",
      "Test 140 Prediction: 6 True Class: 6\n",
      "Test 141 Prediction: 7 True Class: 7\n",
      "Test 142 Prediction: 2 True Class: 3\n",
      "Test 143 Prediction: 1 True Class: 1\n",
      "Test 144 Prediction: 7 True Class: 7\n",
      "Test 145 Prediction: 1 True Class: 1\n",
      "Test 146 Prediction: 8 True Class: 8\n",
      "Test 147 Prediction: 2 True Class: 2\n",
      "Test 148 Prediction: 0 True Class: 0\n",
      "Test 149 Prediction: 1 True Class: 2\n",
      "Test 150 Prediction: 9 True Class: 9\n",
      "Test 151 Prediction: 9 True Class: 9\n",
      "Test 152 Prediction: 5 True Class: 5\n",
      "Test 153 Prediction: 5 True Class: 5\n",
      "Test 154 Prediction: 1 True Class: 1\n",
      "Test 155 Prediction: 5 True Class: 5\n",
      "Test 156 Prediction: 6 True Class: 6\n",
      "Test 157 Prediction: 0 True Class: 0\n",
      "Test 158 Prediction: 3 True Class: 3\n",
      "Test 159 Prediction: 4 True Class: 4\n",
      "Test 160 Prediction: 4 True Class: 4\n",
      "Test 161 Prediction: 6 True Class: 6\n",
      "Test 162 Prediction: 5 True Class: 5\n",
      "Test 163 Prediction: 4 True Class: 4\n",
      "Test 164 Prediction: 6 True Class: 6\n",
      "Test 165 Prediction: 5 True Class: 5\n",
      "Test 166 Prediction: 4 True Class: 4\n",
      "Test 167 Prediction: 5 True Class: 5\n",
      "Test 168 Prediction: 1 True Class: 1\n",
      "Test 169 Prediction: 4 True Class: 4\n",
      "Test 170 Prediction: 9 True Class: 4\n",
      "Test 171 Prediction: 7 True Class: 7\n",
      "Test 172 Prediction: 2 True Class: 2\n",
      "Test 173 Prediction: 3 True Class: 3\n",
      "Test 174 Prediction: 2 True Class: 2\n",
      "Test 175 Prediction: 1 True Class: 7\n",
      "Test 176 Prediction: 1 True Class: 1\n",
      "Test 177 Prediction: 8 True Class: 8\n",
      "Test 178 Prediction: 1 True Class: 1\n",
      "Test 179 Prediction: 8 True Class: 8\n",
      "Test 180 Prediction: 1 True Class: 1\n",
      "Test 181 Prediction: 8 True Class: 8\n",
      "Test 182 Prediction: 5 True Class: 5\n",
      "Test 183 Prediction: 0 True Class: 0\n",
      "Test 184 Prediction: 2 True Class: 8\n",
      "Test 185 Prediction: 9 True Class: 9\n",
      "Test 186 Prediction: 2 True Class: 2\n",
      "Test 187 Prediction: 5 True Class: 5\n",
      "Test 188 Prediction: 0 True Class: 0\n",
      "Test 189 Prediction: 1 True Class: 1\n",
      "Test 190 Prediction: 1 True Class: 1\n",
      "Test 191 Prediction: 1 True Class: 1\n",
      "Test 192 Prediction: 0 True Class: 0\n",
      "Test 193 Prediction: 4 True Class: 9\n",
      "Test 194 Prediction: 0 True Class: 0\n",
      "Test 195 Prediction: 1 True Class: 3\n",
      "Test 196 Prediction: 1 True Class: 1\n",
      "Test 197 Prediction: 6 True Class: 6\n",
      "Test 198 Prediction: 4 True Class: 4\n",
      "Test 199 Prediction: 2 True Class: 2\n",
      "Done!\n",
      "Accuracy: 0.92\n"
     ]
    }
   ],
   "source": [
    "# Launch the graph\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "\n",
    "    # Counter lives in this cell so re-running it alone gives the same result\n",
    "    # (accumulating into a variable from another cell would drift past 1.0).\n",
    "    n_correct = 0\n",
    "    # loop over test data\n",
    "    for i in range(len(Xte)):\n",
    "        # Get nearest neighbor: index of the closest training image\n",
    "        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})\n",
    "        # Get nearest neighbor class label and compare it to its true label\n",
    "        predicted_class = np.argmax(Ytr[nn_index])\n",
    "        true_class = np.argmax(Yte[i])\n",
    "        print(\"Test\", i, \"Prediction:\", predicted_class, \"True Class:\", true_class)\n",
    "        if predicted_class == true_class:\n",
    "            n_correct += 1\n",
    "    # Count then divide once: summing 1./len(Xte) per hit accumulates float rounding error\n",
    "    accuracy = float(n_correct) / len(Xte)\n",
    "    print(\"Done!\")\n",
    "    print(\"Accuracy:\", accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}