{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ORL face recognition with PCA + SVM\n",
    "\n",
    "Project: PCA | File: main.py | IDE: PyCharm | Author: Johnson Yan | Date: 2022/12/11 18:20\n",
    "\n",
    "Loads the ORL face images, projects them onto the principal components that explain 70% of the variance, and classifies the projections with a polynomial-kernel SVM. Each subject's images are split at random into training and test sets; `SEED` in the configuration cell makes that split reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "from glob import glob\n",
    "\n",
    "import cv2\n",
    "import hypertools as hyp\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "DATA_DIR = 'orl_faces/'\n",
    "SEED = 42          # seed for the per-subject train/test split; None -> new random split on every run\n",
    "TRAIN_RATIO = 0.7  # fraction of each subject's images used for training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def orl_dataset(image_path='orl_faces/', train_ratio=0.7, seed=None):\n",
    "    '''Load the ORL face images and split each subject's images into train/test.\n",
    "\n",
    "    Subjects are the sub-folders s1, s2, ... of `image_path`; any other entry (e.g. a\n",
    "    README) is ignored. Images in folder sK get label K - 1.\n",
    "\n",
    "    Args:\n",
    "        image_path: directory containing the subject folders.\n",
    "        train_ratio: fraction of each subject's images used for training (rounded down\n",
    "            to a whole number of images).\n",
    "        seed: seed for the random split; None draws from the global `random` state.\n",
    "\n",
    "    Returns:\n",
    "        (train_data, train_label, test_data, test_label) as numpy arrays; each data row\n",
    "        is one grayscale image flattened with ravel().\n",
    "\n",
    "    Raises:\n",
    "        IOError: if an image file cannot be read.\n",
    "    '''\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "    train_data, train_label, test_data, test_label = [], [], [], []\n",
    "\n",
    "    subjects = sorted(\n",
    "        (name for name in os.listdir(image_path)\n",
    "         if name.startswith('s') and name[1:].isdigit()\n",
    "         and os.path.isdir(os.path.join(image_path, name))),\n",
    "        key=lambda name: int(name[1:]))\n",
    "    for name in subjects:\n",
    "        label = int(name[1:]) - 1\n",
    "        # sorted so the split depends only on the seed, not on directory listing order\n",
    "        paths = sorted(glob(os.path.join(image_path, name, '*.pgm')))\n",
    "        images = []\n",
    "        for path in paths:\n",
    "            img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale\n",
    "            if img is None:  # cv2.imread signals failure by returning None, not by raising\n",
    "                raise IOError('could not read image %s' % path)\n",
    "            images.append(img.ravel())\n",
    "\n",
    "        # split train dataset and test dataset; sample from the real image count\n",
    "        # (not a hard-coded 10) and label every image individually so that data\n",
    "        # and labels always have the same length\n",
    "        train_idx = set(rng.sample(range(len(images)), int(len(images) * train_ratio)))\n",
    "        for i, img in enumerate(images):\n",
    "            if i in train_idx:\n",
    "                train_data.append(img)\n",
    "                train_label.append(label)\n",
    "            else:\n",
    "                test_data.append(img)\n",
    "                test_label.append(label)\n",
    "\n",
    "    return np.array(train_data), np.array(train_label), np.array(test_data), np.array(test_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hand-written PCA\n",
    "\n",
    "Reference implementation of PCA via the Gram matrix; `main` below uses scikit-learn's `PCA`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pca(data, k):\n",
    "    '''Project `data` onto its top-k principal components (Gram-matrix PCA).\n",
    "\n",
    "    Eigen-decomposes the n x n matrix Z Z^T (Z = mean-centred data) rather than the\n",
    "    d x d covariance, which is far cheaper when there are fewer samples than features,\n",
    "    then maps those eigenvectors back to feature space with Z^T.\n",
    "\n",
    "    Args:\n",
    "        data: (n_samples, n_features) array-like.\n",
    "        k: number of components to keep (at most n_samples are available).\n",
    "\n",
    "    Returns:\n",
    "        (projection, data_mean, components): the (n_samples, k) ndarray of projected\n",
    "        data, the (1, n_features) mean and the (n_features, k) unit-norm components;\n",
    "        the last two are returned as np.matrix for backward compatibility.\n",
    "    '''\n",
    "    data = np.asarray(data, dtype=np.float32)\n",
    "    data_mean = data.mean(axis=0, keepdims=True)\n",
    "    z = data - data_mean\n",
    "    # Z Z^T is symmetric: eigh gives real eigenpairs (eig may return complex ones)\n",
    "    # sorted ASCENDING, so reorder to keep the k components with the largest variance.\n",
    "    eigvals, eigvecs = np.linalg.eigh(z @ z.T)\n",
    "    top = np.argsort(eigvals)[::-1][:k]\n",
    "    components = z.T @ eigvecs[:, top]\n",
    "    norms = np.linalg.norm(components, axis=0, keepdims=True)\n",
    "    components = components / np.where(norms == 0, 1, norms)  # leave zero-variance directions as zeros\n",
    "    return z @ components, np.asmatrix(data_mean), np.asmatrix(components)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    '''Train a PCA + polynomial-kernel SVM face classifier on ORL and report test accuracy.\n",
    "\n",
    "    Uses DATA_DIR, TRAIN_RATIO and SEED from the configuration cell.\n",
    "\n",
    "    Returns:\n",
    "        Test accuracy as a fraction in [0, 1]; it is also printed as a percentage.\n",
    "    '''\n",
    "    train_data, train_labels, test_data, test_labels = orl_dataset(DATA_DIR, TRAIN_RATIO, SEED)\n",
    "    n_subjects = len(np.unique(train_labels))\n",
    "    hyp.plot(train_data, 'o')\n",
    "\n",
    "    # keep as many principal components as needed to explain 70% of the variance\n",
    "    pca1 = PCA(n_components=0.7)\n",
    "    train_data_pca = pca1.fit_transform(train_data)  # map the raw training set into the principal-component subspace\n",
    "    test_data_pca = pca1.transform(test_data)  # map the raw test set into the same subspace\n",
    "    hyp.plot(train_data_pca, 'o', n_clusters=n_subjects)  # one cluster per subject\n",
    "\n",
    "    # Only the non-default settings of the original SVC call are kept; every other\n",
    "    # argument it passed was already scikit-learn's default.\n",
    "    clf = SVC(C=1000.0, kernel='poly', degree=3, gamma=0.001, class_weight='balanced')\n",
    "    # PCA is linear, so dividing the scores by 255 equals scaling 8-bit pixels to [0, 1] before PCA\n",
    "    clf.fit(train_data_pca / 255, train_labels)\n",
    "    predict = clf.predict(test_data_pca / 255)\n",
    "\n",
    "    accuracy = np.mean(predict == test_labels)\n",
    "    print('Accuracy: %.2f%%' % (accuracy * 100))\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy = main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}