{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# KNN算法实现" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 读取数据 image label" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from numpy import *\n", "import operator\n", "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib import cm\n", "from os import listdir\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import struct\n", "\n", "#读取图片\n", "def read_image(file_name):\n", " #先用二进制方式把文件都读进来\n", " file_handle=open(file_name,\"rb\") #以二进制打开文档\n", " file_content=file_handle.read() #读取到缓冲区中\n", "\n", " offset=0\n", " head = struct.unpack_from('>IIII', file_content, offset) # 取前4个整数,返回一个元组\n", " offset += struct.calcsize('>IIII')\n", " imgNum = head[1] #图片数\n", " rows = head[2] #宽度\n", " cols = head[3] #高度\n", " # print(imgNum)\n", " # print(rows)\n", " # print(cols)\n", "\n", " #测试读取一个图片是否读取成功\n", " #im = struct.unpack_from('>784B', file_content, offset)\n", " #offset += struct.calcsize('>784B')\n", "\n", " images=np.empty((imgNum , 784))#empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法\n", " image_size=rows*cols#单个图片的大小\n", " fmt='>' + str(image_size) + 'B'#单个图片的format\n", "\n", " for i in range(imgNum):\n", " images[i] = np.array(struct.unpack_from(fmt, file_content, offset))\n", " # images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))\n", " offset += struct.calcsize(fmt)\n", "\n", " return images\n", "\n", " '''bits = imgNum * rows * cols # data一共有60000*28*28个像素值\n", " bitsString = '>' + str(bits) + 'B' # fmt格式:'>47040000B'\n", " imgs = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组\n", " imgs_array=np.array(imgs).reshape((imgNum,rows*cols)) #最后将读取的数据reshape成 【图片数,图片像素】二维数组\n", " return imgs_array'''\n", "\n", "def read_label(file_name):\n", " file_handle = open(file_name, \"rb\") # 以二进制打开文档\n", " file_content = file_handle.read() # 读取到缓冲区中\n", "\n", " head = struct.unpack_from('>II', file_content, 0) # 取前2个整数,返回一个元组\n", " offset = struct.calcsize('>II')\n", "\n", " labelNum = head[1] # label数\n", " # print(labelNum)\n", " bitsString = '>' + str(labelNum) + 'B' # fmt格式:'>47040000B'\n", " label = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组\n", " return np.array(label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# KNN Train" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# KNN算法\n", "def KNN(test_data, images, labels, k):\n", " # images.shape[0]表示的是读取矩阵第一维度的长度,代表行数60k\n", " dataSetSize = images.shape[0] # 60k\n", " # tile函数在行上重复dataSetSize次,在列上重复1次\n", " # 将test_data乘60k。再reshape成784行*60000列\n", " distance1 = tile(test_data, (dataSetSize)).reshape((dataSetSize,784))-images\n", " # 每个元素平方\n", " distance2 = distance1**2\n", " # 矩阵每行相加\n", " distance3 = distance2.sum(axis=1) # 合并成一行\n", " # 开方\n", " distances4 = distance3**0.5\n", " # 欧氏距离计算结束\n", " # 返回从小到大排序的索引\n", " sortedDistIndicies = distances4.argsort()\n", " classCount=np.zeros((10), np.int32) #10是代表10个类别\n", " for i in range(k): # 统计排在前k名的预测结果\n", " ## 找到图片对应的数据\n", " voteIlabel = labels[sortedDistIndicies[i]] \n", " ## 把这个Count + 1\n", " classCount[voteIlabel] += 1\n", "\n", " return np.argmax(classCount),sortedDistIndicies[0:k]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# KNN TEST(用于绘制misclassification rate)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def KNN_TEST(train_x,test_x,train_y,test_y,k): \n", " testNum = test_x.shape[0]\n", " print('测试图片数量:',testNum)\n", " errorCount = 0 # 判断错误的个数\n", " #error = []\n", " #error_image = []\n", " errorCount = 0 # 判断错误的个数\n", " for i in range(testNum):\n", " result,sortlist = KNN(test_x[i], train_x, train_y, k)\n", " # print('返回的结果是: %s, 真实结果是: %s' % (result, test_y[i]))\n", " if result != test_y[i]:\n", " errorCount += 1.0 # 如果mnist验证集的标签和本身标签不一样,则出错\n", " ##print('返回的结果是: %s, 真实结果是: %s' % (result, test_y[i]))\n", " ## 为了更加直观,只打印出错误的预测结果\n", " #error.append(test_y[i])\n", " #error_image.append(sortlist)\n", " error_rate = errorCount / float(testNum) # 计算出错率\n", " acc = 1.0 - error_rate\n", " #print(\"\\n正确预测率: %f\" % (acc))\n", " #print(\"\\n错误预测率: %f\" % (error_rate))\n", " #print(\"\\n错误预测的图片数量: %d\" % errorCount)\n", " #return error,error_image\n", " return error_rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# KNN TEST(传入单张图片,输出K个最相似的图片)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def KNN_TEST_PRINT(train_x,test_x,train_y,test_y,test_id,k):\n", " result,sortlist = KNN(test_x[test_id], train_x, train_y, k)\n", " print('预测结果为:',result,'实际结果为:',test_y[test_id])\n", " #sortlist.append(test_id)\n", " ## 打印图片\n", " fig=plt.figure(figsize=(8,8))\n", " fig.subplots_adjust(left=0,right=1,bottom=0,top=1,hspace=0.05,wspace=0.05)\n", " for i in range(len(sortlist)):\n", " images = np.reshape(train_x[sortlist[i]], [28,28])\n", " ax=fig.add_subplot(6,5,i+1,xticks=[],yticks=[])\n", " ax.imshow(images,cmap=plt.cm.binary,interpolation='nearest')\n", " ax.text(0,7,str(train_y[sortlist[i]]))\n", " plt.show()\n", " plt.pause(5)##停留5秒\n", " plt.close(fig)##关闭图片" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 传入数据" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "train_image = \"MNIST/train-images-idx3-ubyte\"\n", "test_image = \"MNIST/t10k-images-idx3-ubyte\"\n", "train_label = \"MNIST/train-labels-idx1-ubyte\"\n", "test_label = \"MNIST/t10k-labels-idx1-ubyte\"\n", "# 读取数据\n", "train_x = read_image(train_image) # train_image\n", "test_x = read_image(test_image) # test_image\n", "train_y = read_label(train_label) # train_label\n", "test_y = read_label(test_label) # test_label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 设置运行参数" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# 训练图片数\n", "trainNum = 6000\n", "# 测试图片数\n", "testNum = 100\n", "# 距离最小的k个图片\n", "k = 30\n", "\n", "train_Start = random.randint(0,60001-trainNum)\n", "train_End = train_Start+trainNum\n", "train_i = train_x[train_Start:train_End,:]\n", "train_l = train_y[train_Start:train_End]\n", "\n", "test_Start = random.randint(0,10001-testNum)\n", "## test_Start = 0\n", "test_End = test_Start+testNum\n", "test_i = test_x[test_Start:test_End,:]\n", "test_l = test_y[test_Start:test_End]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 运行" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": "预测结果为: 9 实际结果为: 9\n" }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", "image/png": "\n" }, "metadata": {} } ], "source": [ "## 打印输入图片的k近邻图片\n", "import matplotlib\n", "matplotlib.use('TkAgg') ##macOS需要换一下后端才能显示图片。\n", "from matplotlib import pyplot as plt\n", "%matplotlib inline\n", "## 输入图片的编号(随机产生)\n", "test_id = random.randint(0,10000)\n", "KNN_TEST_PRINT(train_i,test_x,train_l,test_y,test_id,k)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "No. 1\n", "训练集图片数量: 6000 区间: 49303 ~ 55303\n", "测试集图片数量: 500 区间: 9438 ~ 9938\n", "测试图片数量: 500\n", "K= 1 accuracy: 0.918\n", "\n", "No. 2\n", "训练集图片数量: 6000 区间: 38563 ~ 44563\n", "测试集图片数量: 500 区间: 4042 ~ 4542\n", "测试图片数量: 500\n", "K= 2 accuracy: 0.868\n", "\n", "No. 3\n", "训练集图片数量: 6000 区间: 21625 ~ 27625\n", "测试集图片数量: 500 区间: 8959 ~ 9459\n", "测试图片数量: 500\n", "K= 3 accuracy: 0.982\n", "\n", "No. 4\n", "训练集图片数量: 6000 区间: 39939 ~ 45939\n", "测试集图片数量: 500 区间: 3489 ~ 3989\n", "测试图片数量: 500\n", "K= 4 accuracy: 0.92\n", "\n", "No. 5\n", "训练集图片数量: 6000 区间: 21587 ~ 27587\n", "测试集图片数量: 500 区间: 257 ~ 757\n", "测试图片数量: 500\n", "K= 5 accuracy: 0.906\n", "\n", "No. 6\n", "训练集图片数量: 6000 区间: 28045 ~ 34045\n", "测试集图片数量: 500 区间: 2604 ~ 3104\n", "测试图片数量: 500\n", "K= 6 accuracy: 0.9359999999999999\n", "\n", "No. 7\n", "训练集图片数量: 6000 区间: 43593 ~ 49593\n", "测试集图片数量: 500 区间: 6955 ~ 7455\n", "测试图片数量: 500\n", "K= 7 accuracy: 0.974\n", "\n", "No. 8\n", "训练集图片数量: 6000 区间: 9651 ~ 15651\n", "测试集图片数量: 500 区间: 9481 ~ 9981\n", "测试图片数量: 500\n", "K= 8 accuracy: 0.908\n", "\n", "No. 9\n", "训练集图片数量: 6000 区间: 34086 ~ 40086\n", "测试集图片数量: 500 区间: 8445 ~ 8945\n", "测试图片数量: 500\n", "K= 9 accuracy: 0.99\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "## 打印misclassification rate曲线\n", "#import datetime\n", "#start = datetime.datetime.now()\n", "import matplotlib\n", "matplotlib.use('TkAgg') ##macOS需要换一下后端才能显示图片\n", "import numpy as np\n", "import math\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "# 训练图片数\n", "trainNum = 6000\n", "# 测试图片数\n", "testNum = 500\n", "\n", "x = np.arange(1, 10, 1)\n", "y = []\n", "for t in x:\n", " ##数据选择\n", " train_Start = random.randint(0,60001-trainNum)\n", " train_End = train_Start+trainNum\n", " train_i = train_x[train_Start:train_End,:]\n", " train_l = train_y[train_Start:train_End]\n", "\n", " test_Start = random.randint(0,10001-testNum)\n", " ## test_Start = 0\n", " test_End = test_Start+testNum\n", " test_i = test_x[test_Start:test_End,:]\n", " test_l = test_y[test_Start:test_End]\n", " \n", " print('\\nNo.',t)\n", " \n", " print('训练集图片数量:',train_i.shape[0],'区间:',train_Start,'~',train_End)\n", " print('测试集图片数量:',test_i.shape[0],'区间:',test_Start,'~',test_End)\n", " \n", " misclassification_rate = KNN_TEST(train_i,test_i,train_l,test_l,t)\n", " print('K=',t,'accuracy:',1-misclassification_rate)\n", " y.append(misclassification_rate)\n", "plt.plot(x, y, label='KNN')\n", "plt.xlabel(\"K\")\n", "plt.ylabel(\"misclassification rate\")\n", "plt.ylim(0,2*max(y))\n", "plt.legend()\n", "plt.show()\n", "#error,error_image=KNN_TEST(train_i,test_i,train_l,test_l,k)\n", "#b=[i+testNum for i in error]\n", "\n", "#for i in range(len(b)):\n", "# print(error[i],error_image[i])\n", "\n", "#end = datetime.datetime.now()\n", "#print ('程序运行时间:',end-start)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.3 64-bit ('base': conda)", "language": "python", "name": "python37364bitbaseconda12cedf669173426cb0cfb548cfc60e96" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3-final" } }, "nbformat": 4, "nbformat_minor": 2 }