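{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# k-Nearest Neighbors (kNN) Algorithm" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The nearest-neighbor algorithm, or k-nearest neighbors (kNN, k-NearestNeighbor) classification, is one of the simplest classification techniques in data mining. The name refers to the k closest neighbors of a sample: each sample can be represented by the k samples nearest to it, and an unknown sample is assigned the class that is most common among those k neighbors." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Common distance metrics: Euclidean distance, Manhattan distance, and Minkowski distance $\\left(\\sum_{i=1}^{n} |x_i - y_i|^p\\right)^{1/p}$, which reduces to Manhattan distance for $p = 1$ and to Euclidean distance for $p = 2$." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Cross-validation: choose hyperparameters such as k by the error rate on held-out (test) data. The goal is the lowest error on the test set; the setting that does best there does not necessarily have the lowest error on the training set." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Limitations of kNN: as the number of dimensions grows, distances become less meaningful (points tend to be almost equally far from one another); as the amount of data grows, the algorithm becomes very slow, because every prediction must scan the whole training set." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "How much data does a machine learning program need for training? First you need to know the dimensionality and the features of the data: they determine how large the training set has to be." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Algorithm. For each unknown point:\n", "\n", "1. Compute the distance from the unknown point to every point whose class is known.\n", "2. Sort the distances in ascending order.\n", "3. Select the k points closest to the unknown point.\n", "4. Count how many of these k points belong to each class.\n", "5. Assign the unknown point the class that occurs most frequently among the k points." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Advantages:\n", "\n", "- Simple, effective, and easy to understand." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Disadvantages:\n", "\n", "- kNN must store the entire training set, so it consumes a lot of memory; with large datasets the hardware requirements are very high.\n", "- The distance from each unknown point to every known point must be computed, which can be time-consuming.\n", "- The classification results are hard to interpret." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Predicting a single test sample" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Enter the dataset (a hand-made, simulated dataset): the first column of X is the tumor size, the second column is the tumor duration (time), and y indicates whether the tumor is benign (0) or malignant (1)." ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Each row of raw_data_X is one sample: [tumor size, tumor time]\n", "raw_data_X = [[3.393533211, 2.331273381],\n", "              [3.110073273, 1.786360121],\n", "              [1.343892307, 3.362874429],\n", "              [3.580243273, 4.671037091],\n", "              [2.274392744, 2.873335573],\n", "              [7.474390402, 4.673011339],\n", "              [5.772024290, 3.560262131],\n", "              [9.122354845, 2.568264233],\n", "              [7.722344298, 3.479979792],\n", "              [7.978408784, 0.773246244]\n", "             ]\n", "# Class labels: 0 = benign, 1 = malignant\n", "raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", 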
"output_type": "stream", "text": [ "[[3.39353321 2.33127338]\n", " [3.11007327 1.78636012]\n", " [1.34389231 3.36287443]\n", " [3.58024327 4.67103709]\n", " [2.27439274 2.87333557]\n", " [7.4743904 4.67301134]\n", " [5.77202429 3.56026213]\n", " [9.12235485 2.56826423]\n", " [7.7223443 3.47997979]\n", " [7.97840878 0.77324624]]\n", "[0 0 0 0 0 1 1 1 1 1]\n" ] } ], "source": [ "# Convert the lists to NumPy arrays to form the training set\n", "X_train = np.array(raw_data_X)\n", "y_train = np.array(raw_data_y)\n", "print(X_train)\n", "print(y_train)\n", "\n", "# Scatter plot of the training set: benign (0) as green circles, malignant (1) as red crosses\n", "plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], c='g', marker='o', label='benign (0)')\n", "plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], c='r', marker='x', label='malignant (1)')\n", "plt.xlabel('size')\n", "plt.ylabel('time')\n", "plt.legend(loc='upper left')\n", "plt.show()" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# The point to predict\n", "x = np.array([5, 3])\n", "\n", "# Plot the new point together with the training data\n", "plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], c='g', marker='o', label='benign')\n", "plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], c='r', marker='x', label='malignant')\n", "plt.scatter(x[0], x[1], c='b', marker='^', label='to predict')\n", "plt.xlabel('size')\n", "plt.ylabel('time')\n", "plt.legend(loc='upper left')\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## The kNN procedure" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.740095064966033, 2.246050932042296, 3.6740714900551215, 2.1927321139070988, 2.728548843438044, 2.986900534321874, 0.9538947320237519, 4.144901113488964, 2.764333387560458, 3.718783561121629]\n" ] } ], "source": [ "# Compute the Euclidean distance from x to every training point\n", "from math import sqrt\n", "\n", "distance = []\n", "\n", "for x_train in X_train:\n", "    d = sqrt(np.sum((x_train - x)**2))\n", "    distance.append(d)\n", "print(distance)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([6, 0, 3, 1, 4, 8, 5, 2, 9, 7], dtype=int64)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Find the k nearest points. Sorting X itself would break its correspondence with y,\n", "# so we want indices instead: np.argsort returns the indices that sort the distances in ascending order.\n", "\n", "nearest = np.argsort(distance)\n", "nearest" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1, 0, 0, 0, 0, 1]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let k = 6 and collect the labels of the k nearest points\n", "k = 6\n", "topK_k = [y_train[i] for i in nearest[:k]]\n", "topK_k" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Counter({0: 4, 1: 2})\n", "[(0, 4)]\n", "[(0, 4), (1, 2)]\n" ] } ], "source": [ "# Tally the votes: how many of the k neighbors belong to each class\n", "from collections import Counter\n", "\n", "# Counter returns a dict subclass whose keys are the distinct values in the list and whose values are their counts\n", "votes = Counter(topK_k)\n", "print(votes)\n", "\n", "# most_common(n) returns the n (label, count) pairs with the highest counts, most common first\n", "print(votes.most_common(1))\n", "print(votes.most_common(2))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The predicted class is the label with the most votes\n", "predict_y = votes.most_common(1)[0][0]\n", "print(predict_y)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }