{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Name:** \\_\\_\\_\\_\\_\n", "\n", "**EID:** \\_\\_\\_\\_\\_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 2: Logistic Regression and Support Vector Machine\n", "\n", "In this tutorial, we predict whether a face image is male or female using a logistic regression model, a support vector machine, naive Bayes, and KNN.\n", "\n", "First we need to initialize Python. Run the cell below." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\HAOYCH~1\\AppData\\Local\\Temp/ipykernel_24896/1552743865.py:4: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`\n", " IPython.core.display.set_matplotlib_formats(\"svg\")\n" ] } ], "source": [ "%matplotlib inline\n", "import IPython.core.display \n", "# setup output image format (Chrome works best)\n", "IPython.core.display.set_matplotlib_formats(\"svg\")\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from numpy import *\n", "from sklearn import *\n", "import os\n", "import zipfile\n", "import fnmatch\n", "random.seed(100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Loading Data and Pre-processing\n", "We first need to load the images. Download `photos-bw.zip` and put it in the same directory as this ipynb file. **Do not unzip the file.** Then run the following cell to load the images." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loading photos-bw/f-039-01.png\n", "loading photos-bw/f-040-01.png\n", "loading photos-bw/f-041-01.png\n", "loading photos-bw/f-042-01.png\n", "loading photos-bw/f-043-01.png\n", "loading photos-bw/f1-001-0.png\n", "loading photos-bw/f1-002-0.png\n", "loading photos-bw/f1-003-0.png\n", "loading photos-bw/f1-004-0.png\n", "loading photos-bw/f1-005-0.png\n", "loading photos-bw/f1-006-0.png\n", "loading photos-bw/f1-007-0.png\n", "loading photos-bw/f1-008-0.png\n", "loading photos-bw/f1-009-0.png\n", "loading photos-bw/f1-010-0.png\n", "loading photos-bw/f1-011-0.png\n", "loading photos-bw/f1-012-0.png\n", "loading photos-bw/f1-013-0.png\n", "loading photos-bw/f1-014-0.png\n", "loading photos-bw/f1-015-0.png\n", "loading photos-bw/m-063-01.png\n", "loading photos-bw/m-064-01.png\n", "loading photos-bw/m-065-01.png\n", "loading photos-bw/m-067-01.png\n", "loading photos-bw/m-069-01.png\n", "loading photos-bw/m-070-01.png\n", "loading photos-bw/m-073-01.png\n", "loading photos-bw/m-074-01.png\n", "loading photos-bw/m-075-01.png\n", "loading photos-bw/m-077-01.png\n", "loading photos-bw/m-083-01.png\n", "loading photos-bw/m-085-01.png\n", "loading photos-bw/m-091-01.png\n", "loading photos-bw/m-097-01.png\n", "loading photos-bw/m-100-01.png\n", "loading photos-bw/m1-003-0.png\n", "loading photos-bw/m1-004-0.png\n", "loading photos-bw/m1-005-0.png\n", "loading photos-bw/m1-008-0.png\n", "loading photos-bw/m1-010-0.png\n", "loading photos-bw/m1-011-0.png\n", "loading photos-bw/m1-012-0.png\n", "loading photos-bw/m1-013-0.png\n", "loading photos-bw/m1-014-0.png\n", "loading photos-bw/m1-015-0.png\n", "loading photos-bw/m1-016-0.png\n", "loading photos-bw/m1-017-0.png\n", "loading photos-bw/m1-018-0.png\n", "loading photos-bw/m1-035-0.png\n", "loading photos-bw/m1-041-0.png\n", "DONE: loaded 50 images\n" ] } ], "source": [ "imgdata = []\n", 
"genders = []\n", "\n", "# load the zip file\n", "filename = 'photos-bw.zip'\n", "zfile = zipfile.ZipFile(filename, 'r')\n", "\n", "for name in zfile.namelist():\n", " # check file name matches\n", " if fnmatch.fnmatch(name, \"photos-bw/*.png\"):\n", " print(\"loading\", name)\n", " # open file in memory, and parse as an image\n", " myfile = zfile.open(name)\n", " img = matplotlib.image.imread(myfile)\n", " myfile.close()\n", " \n", " # append to data\n", " imgdata.append(img)\n", " genders.append( int(name[len(\"photos-bw/\")] == 'm') ) # 0 is female, 1 is male\n", " \n", "zfile.close()\n", "imgsize = img.shape\n", "print(\"DONE: loaded {} images\".format(len(imgdata)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each image is a 45x40 array of pixel values. Run the below code to show an example:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(45, 40)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:40.607365\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "print(img.shape)\n", "plt.imshow(img, cmap='gray', interpolation='nearest')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the below code to show all the images!" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:41.275634\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# function to make an image montage\n", "def image_montage(X, imsize=None, maxw=10):\n", " \"\"\"X can be a list of images, or a matrix of vectorized images.\n", " Specify imsize when X is a matrix.\"\"\"\n", " tmp = []\n", " numimgs = len(X)\n", " \n", " # create a list of images (reshape if necessary)\n", " for i in range(0,numimgs):\n", " if imsize != None:\n", " tmp.append(X[i].reshape(imsize))\n", " else:\n", " tmp.append(X[i])\n", " \n", " # add blanks\n", " if (numimgs > maxw) and (mod(numimgs, maxw) > 0):\n", " leftover = maxw - mod(numimgs, maxw)\n", " meanimg = 0.5*(X[0].max()+X[0].min())\n", " for i in range(0,leftover):\n", " tmp.append(ones(tmp[0].shape)*meanimg)\n", " \n", " # make the montage\n", " tmp2 = []\n", " for i in range(0,len(tmp),maxw):\n", " tmp2.append( hstack(tmp[i:i+maxw]) )\n", " montimg = vstack(tmp2) \n", " return montimg\n", "\n", "plt.figure(figsize=(9,9))\n", "plt.imshow(image_montage(imgdata), cmap='gray', interpolation='nearest')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each image is a 2d array, but the classifier algorithms work on 1d vectors. Run the following code to convert all the images into 1d vectors by flattening. The result should be a matrix where each row is a flattened image." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(50, 1800)\n", "(50,)\n" ] } ], "source": [ "X = empty((50, prod(imgdata[0].shape))) # create empty array\n", "for i,img in enumerate(imgdata):\n", " X[i,:] = ravel(img) # for each image, turn it into a vector\n", "Y = asarray(genders) # convert list to numpy array\n", "print(X.shape)\n", "print(Y.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we will shift the pixel values so that gray is 0.0, black is -0.5 and white is 0.5." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before: min=0.0, max=1.0\n", "After: min=-0.5, max=0.5\n" ] } ], "source": [ "print(\"Before: min={}, max={}\".format(X.min(), X.max()))\n", "X -= 0.5\n", "print(\"After: min={}, max={}\".format(X.min(), X.max()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, split the dataset into a training set and testing set. We select 80% for training and 20% for testing." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(40, 1800)\n", "(10, 1800)\n" ] } ], "source": [ "# randomly split data into 80% train and 20% test set\n", "trainX, testX, trainY, testY = \\\n", " model_selection.train_test_split(X, Y, \n", " train_size=0.80, test_size=0.20, random_state=4487)\n", "\n", "print(trainX.shape)\n", "print(testX.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Logistic Regression\n", "Train a logistic regression classifier. Use cross-validation to select the best $C$ parameter." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE\n", "## HINT\n", "# 1. linear_model.LogisticRegressionCV()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "C = [0.01274275]\n" ] } ], "source": [ "### INSERT YOUR CODE HERE\n", "logreg = linear_model.LogisticRegressionCV(Cs=logspace(-4,4,20), cv=5, n_jobs=-1)\n", "logreg.fit(trainX, trainY)\n", "\n", "print(\"C = \", logreg.C_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use the learned model to predict the genders for the training and testing data. What is the accuracy on the training set? What is the accuracy on the testing set?" 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE\n", "## HINT\n", "# 1. To calculate the accuracy: metrics.accuracy_score(label, pred)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train accuracy = 0.9\n", "test accuracy = 0.8\n" ] } ], "source": [ "### INSERT YOUR CODE HERE\n", "# predict from the model\n", "predYtrain = logreg.predict(trainX)\n", "predYtest = logreg.predict(testX)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(trainY, predYtrain)\n", "print(\"train accuracy =\", acc)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(testY, predYtest)\n", "print(\"test accuracy =\", acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Analyzing the Classifier\n", "Run the below code to show the hyperplane parameter $\\mathbf{w}$ as an image. " ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:48.024135\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# logreg is the learned logistic regression model\n", "wimg = logreg.coef_.reshape(imgsize) # get the w and reshape into an image\n", "mycmap = matplotlib.colors.LinearSegmentedColormap.from_list('mycmap', [\"#0000FF\", \"#FFFFFF\", \"#FF0000\"])\n", "mm = max(wimg.max(), -wimg.min())\n", "plt.imshow(wimg, interpolation='nearest', cmap=mycmap, vmin=-mm, vmax=mm)\n", "plt.colorbar()\n", "plt.title(\"weight image\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Recall that the classifier prediction is based on the sign of the function $f(\\mathbf{x}) = \\mathbf{w}^T\\mathbf{x}+b = \\sum_{i=1}^P w_ix_i + b$. Here each $x_i$ is a pixel in the face image, and $w_i$ is the corresponding weight. Hence, the function multiplies the face image by the weight image, pixel by pixel, and then sums over all pixels.\n", "\n", "For $f(\\mathbf{x})$ to be positive, the positive values of the weight image (red regions) should match the positive values in the face image (white pixels), and the negative values of the weight image (blue regions) should match the negative values in the face image (black pixels).\n", "\n", "Hence, we have the following interpretation:\n", "\n", "\n", "\n", "
| Class | red regions (positive weights) | blue regions (negative weights) | white regions (weights near 0) |
| --- | --- | --- | --- |
| +1 class (male) | white pixels in face image | black pixels in face image | region not important |
| -1 class (female) | black pixels in face image | white pixels in face image | region not important |
\n", " \n", "_Looking at the weight image, what parts of the face image is the classifier looking at to determine the gender? Does it make sense?_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **INSERT YOUR ANSWER HERE**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **INSERT YOUR ANSWER HERE**\n", " - For females, it is looking for long black hair on the side of the face/neck, as well as a small black hair part on the left side. It is also looking at no black hair on the top, i.e., exposed forehead.\n", " - For males, it is looking for hair on top of the head, and no hair around the neck." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at the misclassified faces in the test set. Run the below code to show the misclassifed and correctly classified faces." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:50.530433\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# predYtest are the class predictions on the test set.\n", "\n", "# find misclassified test images\n", "inds = where(predYtest != testY) # get indices of misclassified test images\n", "# make a montage\n", "badimgs = image_montage(testX[inds], imsize=imgsize)\n", "\n", "# find correctly classified test images\n", "inds = where(predYtest == testY)\n", "goodimgs = image_montage(testX[inds], imsize=imgsize)\n", " \n", "plt.figure(figsize=(8,4))\n", "plt.subplot(2,1,1)\n", "plt.imshow(badimgs, cmap='gray', interpolation='nearest')\n", "plt.title('misclassified faces')\n", "plt.subplot(2,1,2)\n", "plt.imshow(goodimgs, cmap='gray', interpolation='nearest')\n", "plt.title('correctly classified faces')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Why did the classifier make incorrect predictions on the misclassified faces?_\n", "- **INSERT YOUR ANSWER HERE**\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **INSERT YOUR ANSWER HERE**\n", " - the misclassified faces are females where the hair is pulled back and not on the side of the face. So to the classifier, they look like male faces." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Support Vector Machine\n", "Now train a support vector machine (SVM) on the same training and testing data. Use cross-validation to select the best $C$ parameter." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE\n", "## HINT \n", "# 1. C: paramgrid = {'C': logspace(-4,4,20)}\n", "# 2. cross-validation: model_selection.GridSearchCV()\n", "# 3. 
SVM: svm.SVC()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'C': array([1.00000000e-04, 2.63665090e-04, 6.95192796e-04, 1.83298071e-03,\n", " 4.83293024e-03, 1.27427499e-02, 3.35981829e-02, 8.85866790e-02,\n", " 2.33572147e-01, 6.15848211e-01, 1.62377674e+00, 4.28133240e+00,\n", " 1.12883789e+01, 2.97635144e+01, 7.84759970e+01, 2.06913808e+02,\n", " 5.45559478e+02, 1.43844989e+03, 3.79269019e+03, 1.00000000e+04])}\n", "{'C': 0.004832930238571752}\n", "---\n", "mean=0.6250 {'C': 0.0001}\n", "mean=0.6250 {'C': 0.00026366508987303583}\n", "mean=0.6250 {'C': 0.0006951927961775605}\n", "mean=0.8750 {'C': 0.0018329807108324356}\n", "mean=0.9000 {'C': 0.004832930238571752}\n", "mean=0.9000 {'C': 0.012742749857031334}\n", "mean=0.9000 {'C': 0.03359818286283781}\n", "mean=0.9000 {'C': 0.08858667904100823}\n", "mean=0.9000 {'C': 0.23357214690901212}\n", "mean=0.9000 {'C': 0.615848211066026}\n", "mean=0.9000 {'C': 1.623776739188721}\n", "mean=0.9000 {'C': 4.281332398719396}\n", "mean=0.9000 {'C': 11.288378916846883}\n", "mean=0.9000 {'C': 29.763514416313132}\n", "mean=0.9000 {'C': 78.47599703514607}\n", "mean=0.9000 {'C': 206.913808111479}\n", "mean=0.9000 {'C': 545.5594781168514}\n", "mean=0.9000 {'C': 1438.44988828766}\n", "mean=0.9000 {'C': 3792.690190732246}\n", "mean=0.9000 {'C': 10000.0}\n" ] } ], "source": [ "### INSERT YOUR CODE HERE\n", "\n", "# setup the list of parameters to try\n", "paramgrid = {'C': logspace(-4,4,20)}\n", "print(paramgrid)\n", "\n", "clf = model_selection.GridSearchCV(svm.SVC(kernel='linear'), paramgrid, cv=5, n_jobs=-1)\n", "\n", "clf.fit(trainX, trainY);\n", "\n", "print(clf.best_params_)\n", "\n", "print(\"---\")\n", "for m,p in zip(clf.cv_results_['mean_test_score'], clf.cv_results_['params']):\n", " print(\"mean={:.4f} {}\".format(m,p))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calculate the training and test accuracy for the SVM 
classifier." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE\n", "## HINT\n", "# 1. To calculate the accuracy: metrics.accuracy_score(label, pred)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train accuracy = 0.9\n", "test accuracy = 0.8\n" ] } ], "source": [ "### INSERT YOUR CODE HERE\n", "\n", "# predict from the model\n", "predYtrain = clf.predict(trainX)\n", "predYtest = clf.predict(testX)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(trainY, predYtrain)\n", "print(\"train accuracy =\", acc)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(testY, predYtest)\n", "print(\"test accuracy =\", acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to before, plot an image of the hyperplane parameters $w$, and view the misclassified and correctly classified test images." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:53.712883\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "### INSERT YOUR CODE HERE\n", "# logreg is the learned logistic regression model\n", "wimg = clf.best_estimator_.coef_.reshape(imgsize)\n", "mm = max(wimg.max(), -wimg.min())\n", "plt.imshow(wimg, interpolation='nearest', cmap=mycmap, vmin=-mm, vmax=mm)\n", "plt.colorbar()\n", "plt.title(\"weight image\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's also look at the misclassified faces in the test set. To show the misclassifed and correctly classified faces like above." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:54.698396\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# find misclassified test images\n", "inds = where(predYtest != testY) # get indices of misclassified test images\n", "# make a montage\n", "badimgs = image_montage(testX[inds], imsize=imgsize)\n", "\n", "# find correctly classified test images\n", "inds = where(predYtest == testY)\n", "goodimgs = image_montage(testX[inds], imsize=imgsize)\n", " \n", "plt.figure(figsize=(8,4))\n", "plt.subplot(2,1,1)\n", "plt.imshow(badimgs, cmap='gray', interpolation='nearest')\n", "plt.title('misclassified faces')\n", "plt.subplot(2,1,2)\n", "plt.imshow(goodimgs, cmap='gray', interpolation='nearest')\n", "plt.title('correctly classified faces')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Are there any differences between the $w$ for logistic regressiona and the $w$ for SVM? Is there any interpretation for the differences?_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **INSERT YOUR ANSWER HERE**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **INSERT YOUR ANSWER HERE**\n", " - The SVM w has more negative weights arround the eye brows and nose, which means it also looks for thicker eye brows and bigger noses in males." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Naive Bayes\n", "Next, we train a naive bayes (NB), which is delieved at last lecture, on the same training and testing data. Then, showing the accuracy of training and testing set." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE\n", "## HINT\n", "# 1. naive_bayes.GaussianNB()\n", "# 2. 
metrics.accuracy_score()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train accuracy = 0.975\n", "test accuracy = 0.9\n" ] } ], "source": [ "### INSERT YOUR CODE HERE\n", "# train NB model\n", "bayesreg = naive_bayes.GaussianNB().fit(trainX,trainY)\n", "### INSERT YOUR CODE HERE\n", "\n", "# predict from the model\n", "predYtrain = bayesreg.predict(trainX)\n", "predYtest = bayesreg.predict(testX)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(trainY, predYtrain)\n", "print(\"train accuracy =\", acc)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(testY, predYtest)\n", "print(\"test accuracy =\", acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's also look at the misclassified faces in the test set. To show the misclassifed and correctly classified faces like above." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:54:58.964309\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# predYtest are the class predictions on the test set.\n", "\n", "# find misclassified test images\n", "inds = where(predYtest != testY) # get indices of misclassified test images\n", "# make a montage\n", "badimgs = image_montage(testX[inds], imsize=imgsize)\n", "\n", "# find correctly classified test images\n", "inds = where(predYtest == testY)\n", "goodimgs = image_montage(testX[inds], imsize=imgsize)\n", " \n", "plt.figure(figsize=(8,4))\n", "plt.subplot(2,1,1)\n", "plt.imshow(badimgs, cmap='gray', interpolation='nearest')\n", "plt.title('misclassified faces')\n", "plt.subplot(2,1,2)\n", "plt.imshow(goodimgs, cmap='gray', interpolation='nearest')\n", "plt.title('correctly classified faces')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. K Nearest Neighbor\n", "In addition, to train a K Nearest Neighbor (KNN) model, which is delieved at last lecture, on the same training and testing data. Using cross-validation to select the best K parameter. Then, showing the accuracy of training and testing set." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "### INSERT YOUR CODE HERE\n", "## HINT \n", "# 1. K: paramgrid = {'n_neighbors': [3,5,10]}\n", "# 2. cross-validation: model_selection.GridSearchCV()\n", "# 3. neighbors.KNeighborsClassifier()\n", "# 4. 
metrics.accuracy_score()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'n_neighbors': [3, 5, 10]}\n" ] } ], "source": [ "### INSERT YOUR CODE HERE\n", "# setup the list of parameters to try\n", "paramgrid = {'n_neighbors': [3,5,10]}\n", "print(paramgrid)\n", "clf = model_selection.GridSearchCV(neighbors.KNeighborsClassifier(), param_grid=paramgrid, cv=5, n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train accuracy = 0.875\n", "test accuracy = 0.8\n" ] } ], "source": [ "clf.fit(trainX, trainY)\n", "# predict from the model\n", "predYtrain = clf.predict(trainX)\n", "predYtest = clf.predict(testX)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(trainY, predYtrain)\n", "print(\"train accuracy =\", acc)\n", "\n", "# calculate accuracy\n", "acc = metrics.accuracy_score(testY, predYtest)\n", "print(\"test accuracy =\", acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's also look at the misclassified faces in the test set. Show the misclassified and correctly classified faces as above." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-09-26T22:55:02.777022\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# predYtest are the class predictions on the test set.\n", "\n", "# find misclassified test images\n", "inds = where(predYtest != testY) # get indices of misclassified test images\n", "# make a montage\n", "badimgs = image_montage(testX[inds], imsize=imgsize)\n", "\n", "# find correctly classified test images\n", "inds = where(predYtest == testY)\n", "goodimgs = image_montage(testX[inds], imsize=imgsize)\n", " \n", "plt.figure(figsize=(8,4))\n", "plt.subplot(2,1,1)\n", "plt.imshow(badimgs, cmap='gray', interpolation='nearest')\n", "plt.title('misclassified faces')\n", "plt.subplot(2,1,2)\n", "plt.imshow(goodimgs, cmap='gray', interpolation='nearest')\n", "plt.title('correctly classified faces')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 1 }