{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## k-nearest-neighbor algorithm in plain Python\n", "\n", "The k-nn algorithm is a simple **supervised** machine learning algorithm that can be used for both classification and regression. It is an **instance-based** algorithm: instead of estimating a model, it stores all training examples in memory and makes predictions using a similarity measure.\n", "\n", "Given an input example, the k-nn algorithm retrieves the k most similar instances from memory. Similarity is defined in terms of distance: the training examples with the smallest Euclidean distance $d(x, x') = \\sqrt{\\sum_j (x_j - x'_j)^2}$ to the input example are considered the most similar.\n", "\n", "The target value of the input example is computed as follows:  \n", "\n", "Classification:  \n", "a) unweighted: output the most common class label among the k nearest neighbors  \n", "b) weighted: for each class label, sum up the weights of the k nearest neighbors that have this label; output the label with the highest total weight  \n", "\n", "Regression:  \n", "a) unweighted: output the average of the target values of the k nearest neighbors  \n", "b) weighted: sum up target value $\\times$ weight over the k nearest neighbors and divide the result by the sum of all weights, i.e. $\\hat{y} = \\frac{\\sum_{i=1}^{k} w_i y_i}{\\sum_{i=1}^{k} w_i}$, where $y_i$ and $w_i$ are the target value and the weight of the $i$-th nearest neighbor  \n", "\n", "In the weighted version, the contribution of each neighbor is *weighted* according to its distance to the query point, so that closer neighbors have more influence on the prediction. Below, we implement the basic unweighted version of the k-nn algorithm for the digits dataset from scikit-learn." 
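] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The four prediction rules listed above differ only in how the targets of the k nearest neighbors are combined, while the class below implements just the unweighted classification rule. The following cell is a minimal sketch of all four rules on toy data. It assumes inverse-distance weights $w_i = 1/(d_i + \\epsilon)$ because the text above does not fix a particular weighting function. The helper `predict_from_neighbors` is hypothetical and is not used by the rest of this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from collections import Counter\n", "\n", "\n", "def predict_from_neighbors(neighbor_targets, neighbor_dists, task='classification',\n", "                           weighted=False, eps=1e-8):\n", "    # Hypothetical helper: combines the targets of the k nearest neighbors\n", "    # according to one of the four rules: (un)weighted classification or regression\n", "    if task not in ('classification', 'regression'):\n", "        raise ValueError('task must be classification or regression')\n", "    targets = np.asarray(neighbor_targets)\n", "    dists = np.asarray(neighbor_dists, dtype=float)\n", "    # Assumption: inverse-distance weights; eps avoids a division by zero\n", "    # when a neighbor coincides with the query point\n", "    weights = 1.0 / (dists + eps) if weighted else np.ones_like(dists)\n", "\n", "    if task == 'classification':\n", "        # Sum the weights per class label and output the label with the highest\n", "        # total (with unit weights this is a plain majority vote)\n", "        totals = Counter()\n", "        for label, w in zip(targets, weights):\n", "            totals[label] += w\n", "        return max(totals, key=totals.get)\n", "\n", "    # Regression: sum(target * weight) / sum(weights)\n", "    # (with unit weights this is the plain average)\n", "    return float(np.sum(targets * weights) / np.sum(weights))\n", "\n", "\n", "# Toy example: two neighbors of class 1 far away, one neighbor of class 7 very close\n", "labels, dists = [1, 1, 7], [2.0, 2.0, 0.5]\n", "print(predict_from_neighbors(labels, dists))                 # majority vote -> 1\n", "print(predict_from_neighbors(labels, dists, weighted=True))  # total weight ~2.0 for 7 vs ~1.0 for 1 -> 7\n", "\n", "values, value_dists = [10.0, 20.0, 60.0], [1.0, 1.0, 2.0]\n", "print(predict_from_neighbors(values, value_dists, task='regression'))  # average -> 30.0\n", "print(round(predict_from_neighbors(values, value_dists, task='regression', weighted=True), 4))  # ~60 / 2.5 -> 24.0" 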
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2018-03-26T14:32:41.915819Z", "start_time": "2018-03-26T14:32:41.094749Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.datasets import load_digits\n", "from sklearn.model_selection import train_test_split\n", "np.random.seed(123)\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2018-03-26T14:33:11.784085Z", "start_time": "2018-03-26T14:33:10.849626Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X_train shape: (1347, 64)\n", "y_train shape: (1347,)\n", "X_test shape: (450, 64)\n", "y_test shape: (450,)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk0AAAFwCAYAAACl9k+2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAGsZJREFUeJzt3V/InuV9B/Dfb3FCV52xrK6g4r/WQU+MMxSGMGOno/vDkoNZLGwknsSTDgODNTsynulBZzwYw9DVCOsQ7NakjNJOma9jJ8WkvuI0rdgQMXXDSpM4VmjQXTtIWv806Xtdb5/7ed/n5+cDUt/k9z73deeb+/Hb+3mf58rWWgAA8Iv9ylovAABgEShNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOlwwxYNm5qQfM37ppZcOzV9++eVD82+++ebQfETED37wg6H5t99+e/gYI1prOYvHmTrLUddff/3Q/AUXjP8VH83y1KlTw8cYUTXLiy66aGj+4x//+PAxfvzjHw/Nv/TSS8PHGPRGa+2js3igqfP82Mc+NjQ/+jz7k5/8ZGg+IuLIkSND855nV2fDhg1D81dfffXwMb7//e8Pf8/Euq7NSUrT1G677bah+fvvv39o/sknnxyaj4jYvXv30PyJEyeGj0HEvn37huY3btw4fIx77713aP7gwYPDxyBi8+bNQ/MHDhwYPsby8vLQ/JYtW4aPMeiVqQ8wK9u3bx+aH32ePXr06NB8xPjfGc+zq3PxxRcPzX/xi18cPsa2bduGv2diXdeml+cAADp0labM/Exmfi8zX87MsVsqrCuyrEWedciyFnnWtGJpyswNEfG3EfEHEfHJiPhcZn5y6oUxe7KsRZ51yLIWedbVc6fpUxHxcmvtaGvtdEQ8FhFbp10WE5FlLfKsQ5a1yLOontJ0eUS8+q6vj5/9tffIzJ2ZeSgzD81qccycLGtZMU9ZLgzXZi2uzaJ63j13rrdU/tzbI1tr+yJiX8T6e/skPyPLWlbMU5YLw7VZi2uzqJ47Tccj4sp3fX1FRLw2zXKYmCxrkWcdsqxFnkX1lKZnIuITmXlNZl4YEXdGxNenX
RYTkWUt8qxDlrXIs6gVX55rrb2VmZ+PiG9FxIaI+HJr7YXJV8bMybIWedYhy1rkWVfXJ4K31r4REd+YeC3MgSxrkWcdsqxFnjUt5DYqox/Xf+211w7Nj+5tFxHxox/9aGj+s5/97ND8448/PjRf1cmTJ4fmb7nlluFj3HrrrUPztlE5Y9OmTUPzTz311ND8avb4W82eWFWNPm/ecccdQ/N333330PzDDz88NB8RcdNNNw3Nr2ZLLCJ27NgxND+6XdEis40KAEAHpQkAoIPSBADQQWkCAOigNAEAdFCaAAA6KE0AAB2UJgCADkoTAEAHpQkAoIPSBADQYV3sPTe6n9DoXnLXXXfd0PzRo0eH5iMinnjiiaH50XOuuvfc6H5lW7ZsmWYh7/JB2kdplrZt2zY0/9xzzw3NHzhwYGg+IuLee+8d/p6q9u3bNzT/wAMPDM0fOnRoaH41z7P2kludjRs3Ds2P7j23d+/eofmI6feFPHbs2CSP604TAEAHpQkAoMOKpSkzr8zMpzLzSGa+kJn3zGNhzJ4sa5FnHbKsRZ519fxM01sR8Zette9k5sURcTgzn2itvTjx2pg9WdYizzpkWYs8i1rxTlNr7b9aa985++//ExFHIuLyqRfG7MmyFnnWIcta5FnX0LvnMvPqiLgxIr59jt/bGRE7Z7IqJifLWs6XpywXj2uzFtdmLd2lKTMvioh/iohdrbU33//7rbV9EbHv7Gyb2QqZOVnW8ovylOVicW3W4tqsp+vdc5n5q3Em+K+01v552iUxJVnWIs86ZFmLPGvqefdcRsTfR8SR1trfTL8kpiLLWuRZhyxrkWddPXeabo6IP4+IT2fm8tl//nDidTENWdYizzpkWYs8i1rxZ5paa/8RETmHtTAxWdYizzpkWYs861oXe89deumlQ/OHDx8eml/NHkejRtdU1a5du4bm9+zZMzR/ySWXDM2vxtLS0uTHqGh0/6nRvaFWs7/VwYMHh7+nqtHnwdE9PkfnV7OP3Oh/K06cODF8jIpG95Ib3Rdu//79Q/MR49fzyZMnh+ZH/9vSyzYqAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAh4Xce241exZNzZ5IZ4zuJzS6Z9E8/tw2btw4+TEWweifw+i+g9u2bRuaX43RPbd4x+hedR/5yEeG5p944omh+dV8z+233z40vyjPy1u3bh2af/DBB4fmH3300aH51bjnnnuG5u+6666JVjLGnSYAgA5KEwBAh+7SlJkbMvPZzPyXKRfE9GRZizzrkGUt8qxn5E7TPRFxZKqFMFeyrEWedciyFnkW01WaMvOKiPijiPjStMtharKsRZ51yLIWedbUe6dpb0T8VUT834RrYT5kWYs865BlLfIsaMXSlJl/HBGvt9YOrzC3MzMPZeahma2OmZJlLT15ynIxuDZrcW3W1XOn6eaI+JPMPBYRj0XEpzPzH94/1Frb11rb3FrbPOM1MjuyrGXFPGW5MFybtbg2i1qxNLXW/rq1dkVr7eqIuDMi/q219meTr4yZk2Ut8qxDlrXIsy6f0wQA0GFoG5XW2lJELE2yEuZKlrXIsw5Z1iLPWtxpAgDosC427B3dJPGmm26aaCVnjG6+GzG+pscff3z4GMzHpk2bhuaXl5cnWsna2rNnz9D86Aaco1azwe/JkycnWAnnMvo8PrqZbkTEww8/PDT/hS98YWh+9+7dQ/Nr5dSpU5POb9++fWh+9DlzNQ4cODD5MXq40wQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRYF3vPHT16dGh+dJ+3O+64Y9L51XjggQcmPwb8Mvbv3z80v2XLlqH5G264YWh+NXtPHTx4cGj+kUcemfTxF8n9998/NP/kk08Oza9mj8/bbrttaL7qHp9LS0tD8xs3bhyaH91LbnQ9ERGPPvro0
Px62UfSnSYAgA5dpSkzN2bmVzPzu5l5JDN/Z+qFMQ1Z1iLPOmRZizxr6n157qGI+GZr7U8z88KI+LUJ18S0ZFmLPOuQZS3yLGjF0pSZvx4RvxsROyIiWmunI+L0tMtiCrKsRZ51yLIWedbV8/LctRHxw4h4JDOfzcwvZeaHJ14X05BlLfKsQ5a1yLOontJ0QUT8dkT8XWvtxoj434jY/f6hzNyZmYcy89CM18jsyLKWFfOU5cJwbdbi2iyqpzQdj4jjrbVvn/36q3HmL8N7tNb2tdY2t9Y2z3KBzJQsa1kxT1kuDNdmLa7NolYsTa21/46IVzPzt87+0u9FxIuTropJyLIWedYhy1rkWVfvu+f+IiK+cvYdAEcj4q7plsTEZFmLPOuQZS3yLKirNLXWliPCLcQCZFmLPOuQZS3yrMknggMAdFjIved27/65N5X8QqN7KB0+fHhoPiJi82b/h2I1RvcTGt3ra+vWrUPzEeN7qI3u0bYolpeXh+ZH96sand+zZ8/QfMR4/seOHRuar7z33IkTJ4bmH3744YlW8o7RveTuvvvuiVZS2+jz8iWXXDJ8jEV93nSnCQCgg9IEANBBaQIA6KA0AQB0UJoAADooTQAAHZQmAIAOShMAQAelCQCgg9IEANBBaQIA6JCttdk/aOYPI+KVc/zWb0TEGzM/4Pq1Vud7VWvto7N4IFn+jCxrkWcdsqxlXec5SWk678EyD7XWPjA721Y+38rndi6Vz7fyuZ1P5XOufG7nUvl8K5/b+az3c/byHABAB6UJAKDDvEvTvjkfb61VPt/K53Yulc+38rmdT+Vzrnxu51L5fCuf2/ms63Oe6880AQAsKi/PAQB0mEtpyszPZOb3MvPlzNw9j2Oupcw8lpnPZ+ZyZh5a6/XMmjzrkGUdH7QsI+RZyaJkOfnLc5m5ISJeiojbI+J4RDwTEZ9rrb046YHXUGYei4jNrbVyn68hzzpkWccHMcsIeVayKFnO407TpyLi5dba0dba6Yh4LCK2zuG4TEOedciyDlnWIs91ah6l6fKIePVdXx8/+2uVtYj418w8nJk713oxMybPOmRZxwcxywh5VrIQWV4wh2PkOX6t+lv2bm6tvZaZl0XEE5n53dbav6/1omZEnnXylKUsF50861iILOdxp+l4RFz5rq+viIjX5nDcNdNae+3s/74eEV+LM7daq5BnnTxlKcuFJs86FiXLeZSmZyLiE5l5TWZeGBF3RsTX53DcNZGZH87Mi3/67xHx+xHxn2u7qpmSZ508ZSnLhSXPOhYpy8lfnmutvZWZn4+Ib0XEhoj4cmvthamPu4Z+MyK+lpkRZ/58/7G19s21XdLsyLNOnrKU5YKTZx0Lk6VPBAcA6OATwQEAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBA
HRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKDDBVM8aGa2KR73p66//vqh+dOnTw/NHzt2bGh+PWqt5SweZ+osR41mf8EF43/FX3zxxeHvmdKiZHnZZZcNzW/YsGFo/tJLLx2aj4j40Ic+NDT/9ttvD80///zzo4//Rmvto0PfdB5T53nllVcOzW/cuHFo/o033hiaj4h4/fXXh+ZH8xy1KNfmddddNzQ/em2+9NJLQ/PrVNe1ma3NPqup/wIsLS0NzY+WoB07dgzNr0eLcjGPGs1+9Ik8ImLTpk3D3zOlRcly165dQ/Oj2Wzbtm1oPiLihhtuGJo/derU0PzVV189NH/y5MnDrbXNQ990HlPnuXfv3qH50Xz2798/NB8xvqaTJ08OH2PEolybBw4cGJofvTa3bNkyNL9OdV2bXS/PZeZnMvN7mflyZu7+5dfGWpFlLfKsQ5a1yLOmFUtTZm6IiL+NiD+IiE9GxOcy85NTL4zZk2Ut8qxDlrXIs66eO02fioiXW2tHW2unI+KxiNg67bKYiCxrkWcdsqxFnkX1lKbLI+LVd319/OyvsXhkWYs865BlLfIsquetRef6Qbef+6G1zNwZETt/6RUxJVnWsmKeslwYrs1aXJtF9ZSm4xHx7veeXhERr71/qLW2LyL2Ray/d1zxM7KsZcU8ZbkwXJu1uDaL6nl57pmI+ERmXpOZF0bEnRHx9WmXxURkWYs865BlLfIsasU7Ta21tzLz8xHxrYjYEBFfbq29MPnKmDlZ1iLPOmRZizzr6vq45NbaNyLiGxOvhTmQZS3yrEOWtcizpkm2UZna6Kfw3nLLLUPz27dvH5qPiHjllVeG5kfPoaqtW8fehTua5X333Tc0z/yMflrz6CeOr+Z7Rj8JeepPnF5LU38y/mp2Xhj95Okin1T9c0b/+zH6PDtqNTuLPPfcc0Pz62WnBhv2AgB0UJoAADooTQAAHZQmAIAOShMAQAelCQCgg9IEANBBaQIA6KA0AQB0UJoAADooTQAAHRZy77nR/Z6uuuqqoflTp04NzUdELC0tDc3b4+qMqfeGO3DgwKSPzzv27t076ePv2bNn+HtG9+iqulfZaiwvLw/NHzt2bGh+NXvPjT4PjuY5+jy+Vkb/+zHq6aefHpofzT5ica81d5oAADooTQAAHVYsTZl5ZWY+lZlHMvOFzLxnHgtj9mRZizzrkGUt8qyr52ea3oqIv2ytfSczL46Iw5n5RGvtxYnXxuzJshZ51iHLWuRZ1Ip3mlpr/9Va+87Zf/+fiDgSEZdPvTBmT5a1yLMOWdYiz7qGfqYpM6+OiBsj4ttTLIb5kWUt8qxDlrXIs5bujxzIzIsi4p8iYldr7c1z/P7OiNg5w7UxEVnW8ovylOVicW3W4tqsp6s0Zeavxpngv9Ja++dzzbTW9kXEvrPzbWYrZKZkWctKecpycbg2a3Ft1tTz7rmMiL+PiCOttb+ZfklMRZa1yLMOWdYiz7p6fqbp5oj484j4dGYun/3nDydeF9OQZS3yrEOWtcizqBVfnmut/UdE5BzWwsRkWYs865BlLfKsayH3nhvd5+aGG24Ymr/kkkuG5iPG92mqupfcqNE9lJ577rmh+dFceMfo3lBT7yW1a9euSR8/ImLbtm1D8/v3759mIevA6Lk9++yzQ/Oj+wJGjD9vrmZPtEUw9XmNXger2eNz6v3zpmIbFQCADkoTAEAHpQkAoIPSBADQQWkCAOigNAEAdFCaAAA6KE0AAB2UJgCADkoTAEAHpQkAoMNC7j03ui/O6J5YmzZtGpqPiHjwwQeHv2fE3r17J338tTK6/9Donkur2a9sdB8l+1udMXrdT
L1XXcT4c8XS0tI0C1lAU+8Ndssttwx/zzXXXDM0X/XaHN2Db3TPzhMnTgzNP/TQQ0PzEePPF6N7FU6VvTtNAAAdlCYAgA7dpSkzN2Tms5n5L1MuiOnJshZ51iHLWuRZz8idpnsi4shUC2GuZFmLPOuQZS3yLKarNGXmFRHxRxHxpWmXw9RkWYs865BlLfKsqfdO096I+KuI+L8J18J8yLIWedYhy1rkWdCKpSkz/zgiXm+tHV5hbmdmHsrMQzNbHTMly1p68pTlYnBt1uLarKvnTtPNEfEnmXksIh6LiE9n5j+8f6i1tq+1trm1tnnGa2R2ZFnLinnKcmG4NmtxbRa1Ymlqrf11a+2K1trVEXFnRPxba+3PJl8ZMyfLWuRZhyxrkWddPqcJAKDD0DYqrbWliFiaZCXMlSxrkWcdsqxFnrW40wQA0GEhN+wdtR434RzdfLCq0U0VRzf5XM2mo6ObL994441D88vLy0Pza2U0m9HNcVtrkz5+xPq89tfK6AapTz311ND8fffdNzS/mufA0c20R//OVN3gdzT70fl5PKeNblq/mueLHu40AQB0UJoAADooTQAAHZQmAIAOShMAQAelCQCgg9IEANBBaQIA6KA0AQB0UJoAADooTQAAHRZy77mtW7cOzZ86dWpofs+ePUPzqzG6h1JV+/fvH5of3RduNXtJje6JNbrH0aLsPTdqdG+o0evy6aefHprnvUavhdF8RvNfzd5zzz777ND8jh07hubn8dy/CEafo0azjxjPZqq95Ea50wQA0KGrNGXmxsz8amZ+NzOPZObvTL0wpiHLWuRZhyxrkWdNvS/PPRQR32yt/WlmXhgRvzbhmpiWLGuRZx2yrEWeBa1YmjLz1yPidyNiR0REa+10RJyedllMQZa1yLMOWdYiz7p6Xp67NiJ+GBGPZOazmfmlzPzwxOtiGrKsRZ51yLIWeRbVU5ouiIjfjoi/a63dGBH/GxG73z+UmTsz81BmHprxGpkdWdayYp6yXBiuzVpcm0X1lKbjEXG8tfbts19/Nc78ZXiP1tq+1trm1trmWS6QmZJlLSvmKcuF4dqsxbVZ1IqlqbX23xHxamb+1tlf+r2IeHHSVTEJWdYizzpkWYs86+p999xfRMRXzr4D4GhE3DXdkpiYLGuRZx2yrEWeBXWVptbackS4hViALGuRZx2yrEWeNflEcACADgu599ytt946NH/PPfdMtJJ3PProo0PzS0tL0yxkwYzuPTe6X9Xo/kYR49nYR/CMLVu2DM1v3759aP7kyZND87zX6J/f6HVw4sSJofnRve0iIg4ePDg0v5o90Soa/XPYtGnT0PzGjRuH5iPGny/Wy56d7jQBAHRQmgAAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHRQmgAAOihNAAAdsrU2+wfN/GFEvHKO3/qNiHhj5gdcv9bqfK9qrX10Fg8ky5+RZS3yrEOWtazrPCcpTec9WOah1trmuR1wjVU+38rndi6Vz7fyuZ1P5XOufG7nUvl8K5/b+az3c/byHABAB6UJAKDDvEvTvjkfb61VPt/K53Yulc+38rmdT+Vzrnxu51L5fCuf2/ms63Oe6880AQAsKi/PAQB0mEtpyszPZOb3MvPlzNw9j2Oupcw8lpnPZ+ZyZh5a6/XMmjzrkGUdH7QsI+RZyaJkOfnLc5m5ISJeiojbI+J4RDwTEZ9rrb046YHXUGYei4jNrbVyn68hzzpkWccHMcsIeVayKFnO407TpyLi5dba0dba6Yh4LCK2zuG4TEOedciyDlnWIs91ah6l6fKIePVdXx8/+2uVtYj418w8nJk713oxMybPOmRZxwcxywh5VrIQWV4wh2PkOX6t+lv2bm6tvZaZl0XEE5n53dbav6/1omZEnnXylKUsF50861iILOdxp+l4RFz5rq+viIjX5nDcNdNae+3s/74eEV+LM7daq5BnnTxlKcuFJs86FiXLeZSmZyLiE
5l5TWZeGBF3RsTX53DcNZGZH87Mi3/67xHx+xHxn2u7qpmSZ508ZSnLhSXPOhYpy8lfnmutvZWZn4+Ib0XEhoj4cmvthamPu4Z+MyK+lpkRZ/58/7G19s21XdLsyLNOnrKU5YKTZx0Lk6VPBAcA6OATwQEAOihNAAAdlCYAgA5KEwBAB6UJAKCD0gQA0EFpAgDooDQBAHT4f2BS5OfAvcPqAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# We will use the digits dataset as an example. It consists of 1797 images of hand-written digits (8x8 pixels each).\n", "# Each digit is represented by a 64-dimensional vector of pixel values.\n", "\n", "digits = load_digits()\n", "X, y = digits.data, digits.target\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", "print(f'X_train shape: {X_train.shape}')\n", "print(f'y_train shape: {y_train.shape}')\n", "print(f'X_test shape: {X_test.shape}')\n", "print(f'y_test shape: {y_test.shape}')\n", "\n", "# Example digits\n", "fig = plt.figure(figsize=(10,8))\n", "for i in range(10):\n", "    ax = fig.add_subplot(2, 5, i+1)\n", "    ax.imshow(X[i].reshape((8,8)), cmap='gray')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## k-nearest-neighbor class" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2018-03-26T14:33:23.932644Z", "start_time": "2018-03-26T14:33:23.803735Z" } }, "outputs": [], "source": [ "class kNN():\n", "    def __init__(self):\n", "        pass\n", "\n", "    def fit(self, X, y):\n", "        # k-nn is instance-based: fitting just stores the training examples\n", "        self.data = X\n", "        self.targets = y\n", "\n", "    def euclidean_distance(self, X):\n", "        \"\"\"\n", "        Computes the Euclidean distance between the training data and\n", "        a new input example or matrix of input examples X\n", "        \"\"\"\n", "        # input: single data point -> distances of shape (n_train,)\n", "        if X.ndim == 1:\n", "            l2 = np.sqrt(np.sum((self.data - X)**2, axis=1))\n", "\n", "        # input: matrix of data points -> distances of shape (n_samples, n_train)\n", "        elif X.ndim == 2:\n", "            n_samples, _ = X.shape\n", "            l2 = [np.sqrt(np.sum((self.data - X[i])**2, axis=1)) for i in range(n_samples)]\n", "\n", "        else:\n", "            raise ValueError('X must be a 1D or 2D array')\n", "\n", "        return np.array(l2)\n", "\n", "    def predict(self, X, k=1):\n", "        \"\"\"\n", "        Predicts the class label for an input example or matrix of input examples X\n", "        \"\"\"\n", "        # step 1: compute distance between input and training data\n", "        dists = self.euclidean_distance(X)\n", "\n", "        # step 2: find the k 
nearest neighbors and their classifications\n", "        if X.ndim == 1:\n", "            if k == 1:\n", "                nn = np.argmin(dists)\n", "                return self.targets[nn]\n", "            else:\n", "                knn = np.argsort(dists)[:k]\n", "                y_knn = self.targets[knn]\n", "                # majority vote among the k nearest neighbors\n", "                max_vote = max(y_knn, key=list(y_knn).count)\n", "                return max_vote\n", "\n", "        if X.ndim == 2:\n", "            knn = np.argsort(dists)[:, :k]\n", "            y_knn = self.targets[knn]\n", "            if k == 1:\n", "                # single neighbor: return a 1D array of shape (n_samples,)\n", "                return y_knn[:, 0]\n", "            else:\n", "                n_samples, _ = X.shape\n", "                max_votes = [max(y_knn[i], key=list(y_knn[i]).count) for i in range(n_samples)]\n", "                # return a 1D NumPy array, consistent with the k == 1 case\n", "                return np.array(max_votes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initializing and training the model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2018-03-26T14:33:34.324040Z", "start_time": "2018-03-26T14:33:34.282266Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing one datapoint, k=1\n", "Predicted label: 8\n", "True label: 8\n", "\n", "Testing one datapoint, k=5\n", "Predicted label: 3\n", "True label: 3\n", "\n", "Testing 10 datapoints, k=1\n", "Predicted labels: [5 4 5 5 6 6 1 0 8 8]\n", "True labels: [5 4 5 5 6 6 1 0 8 8]\n", "\n", "Testing 10 datapoints, k=4\n", "Predicted labels: [5 4 5 5 6 6 1 0 8 8]\n", "True labels: [5 4 5 5 6 6 1 0 8 8]\n", "\n" ] } ], "source": [ "knn = kNN()\n", "knn.fit(X_train, y_train)\n", "\n", "print(\"Testing one datapoint, k=1\")\n", "print(f\"Predicted label: {knn.predict(X_test[0], k=1)}\")\n", "print(f\"True label: {y_test[0]}\")\n", "print()\n", "print(\"Testing one datapoint, k=5\")\n", "print(f\"Predicted label: {knn.predict(X_test[20], k=5)}\")\n", "print(f\"True label: {y_test[20]}\")\n", "print()\n", "print(\"Testing 10 datapoints, k=1\")\n", "print(f\"Predicted labels: {knn.predict(X_test[5:15], k=1)}\")\n", "print(f\"True labels: {y_test[5:15]}\")\n", "print()\n", "print(\"Testing 10 datapoints, k=4\")\n", "print(f\"Predicted labels: {knn.predict(X_test[5:15], k=4)}\")\n", "print(f\"True labels: {y_test[5:15]}\")\n", "print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Accuracy on test set" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2018-03-26T14:33:36.781872Z", "start_time": "2018-03-26T14:33:36.495726Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy with k = 1: 99.11111111111111\n", "Test accuracy with k = 5: 98.66666666666667\n" ] } ], "source": [ "# Compute accuracy (in percent) on the test set\n", "y_p_test1 = knn.predict(X_test, k=1)\n", "test_acc1 = np.mean(y_p_test1 == y_test) * 100\n", "print(f\"Test accuracy with k = 1: {test_acc1}\")\n", "\n", "y_p_test5 = knn.predict(X_test, k=5)\n", "test_acc5 = np.mean(y_p_test5 == y_test) * 100\n", "print(f\"Test accuracy with k = 5: {test_acc5}\")" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }