{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'\\nCLASS: Model evaluation procedures\\n'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "'''\n", "CLASS: Model evaluation procedures\n", "'''" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/sinanozdemir/anaconda/envs/sfdat26-env/lib/python2.7/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n", " warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# read in the iris data\n", "from sklearn.datasets import load_iris\n", "iris = load_iris()\n", "X, y = iris.data, iris.target" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "## TRAIN AND TEST ON THE SAME DATA (OVERFITTING)\n", "\n", "from sklearn.neighbors import KNeighborsClassifier\n", "knn = KNeighborsClassifier(n_neighbors=1)\n", "knn.fit(X, y)\n", "knn.score(X, y)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "## TEST SET APPROACH\n", "\n", "# understanding train_test_split\n", "# (it moved to sklearn.model_selection in scikit-learn 0.18 and the old\n", "# sklearn.cross_validation module was removed in 0.20, so try the new\n", "# location first and fall back to the old one on legacy installs)\n", "try:\n", "    from sklearn.model_selection import train_test_split\n", "except ImportError:\n", "    from sklearn.cross_validation import train_test_split" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 0, 10],\n", " [ 1, 11],\n", " [ 2, 12],\n", " [ 3, 13],\n", " [ 4, 14],\n", 
" [ 5, 15],\n", " [ 6, 16],\n", " [ 7, 17],\n", " [ 8, 18],\n", " [ 9, 19]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features = np.array([range(10), range(10, 20)]).T\n", "features # 2D array" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "['even', 'odd', 'even', 'odd', 'even', 'odd', 'even', 'odd', 'even', 'odd']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "response = ['even', 'odd'] * 5\n", "response # 1D array\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# step 1: split data into training set and test set\n", "features_train, features_test, response_train, response_test \\\n", "= train_test_split(features, response, random_state=4)\n", "# the random_state allows us all to get the same random numbers" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 9, 19],\n", " [ 2, 12],\n", " [ 6, 16],\n", " [ 0, 10],\n", " [ 1, 11],\n", " [ 5, 15],\n", " [ 7, 17]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features_train # the training set: 70% of the data" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 3, 13],\n", " [ 8, 18],\n", " [ 4, 14]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features_test # the test set: the remaining 30% of the data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "['odd', 'even', 'even', 'even', 'odd', 'odd', 'odd']" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "response_train # 70% of the responses, the SAME 
ones as features_train" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "['odd', 'even', 'even']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "response_test # remaining 30%, the SAME rows as features_test" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# steps 2 and 3: calculate test set accuracy for K=1\n", "knn = KNeighborsClassifier(n_neighbors=1)\n", "knn.fit(features_train, response_train) # Note that I fit on the training set\n", "knn.score(features_test, response_test) # and scored (accuracy) on the test set" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.66666666666666663" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Suppose you watch a soccer game and memorize their patterns\n", "# If you rewind the game and then were asked to predict\n", "# the outcome, you'd know right?! 
That's why we train on one game\n", "# and then get tested on our predictive ability on another game\n", "\n", "# Overfitting\n", "\n", "# step 4 (parameter tuning): calculate test set accuracy for K=3\n", "knn = KNeighborsClassifier(n_neighbors=3)\n", "knn.fit(features_train, response_train)\n", "knn.score(features_test, response_test)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n", " weights='uniform')" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# steps 5 and 6: choose the final model and train it on all the data\n", "# (note: this cell uses K=5 and the full iris data X, y, not the toy even/odd data above)\n", "knn = KNeighborsClassifier(n_neighbors=5)\n", "knn.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([1, 0])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# There are two types of data we will deal with in ML:\n", "# in-sample and out-of-sample data\n", "# the in-sample data consists of our training and test sets\n", "# Note we know the answer to all of them!\n", "# the out-of-sample data are data that we really haven't seen before\n", "# They're the reason we built the model in the first place!\n", "\n", "# step 7: make predictions on new (\"out of sample\") data\n", "out_of_sample = [[5, 4, 3, 2], [4, 3, 2, 1]]\n", "knn.predict(out_of_sample)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# verify that a different train/test split can result in a different test set error\n", "features_train, features_test, response_train, response_test \\\n", "= train_test_split(X, y, random_state=1)\n", "# I used a different random state so we all get the same results\n", "# but different from the random_state 
= 4 from before!" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "knn = KNeighborsClassifier(n_neighbors=3)\n", "knn.fit(features_train, response_train)\n", "knn.score(features_test, response_test)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Using a training set and test set is so important\n", "# Just as important is cross validation. Cross validation is\n", "# Just using several different train test splits and \n", "# averaging your results!\n", "\n", "## CROSS-VALIDATION\n", "\n", "# check CV score for K=1\n", "# (cross_val_score moved to sklearn.model_selection in scikit-learn 0.18 and\n", "# the old sklearn.cross_validation module was removed in 0.20, so try the new\n", "# location first and fall back to the old one on legacy installs)\n", "try:\n", "    from sklearn.model_selection import cross_val_score\n", "except ImportError:\n", "    from sklearn.cross_validation import cross_val_score\n", "knn = KNeighborsClassifier(n_neighbors=1)\n", "scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([ 0.96666667, 0.96666667, 0.93333333, 0.93333333, 1. 
])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores # It ran a KNN 5 times!\n", "# We are looking at the accuracy for each of the 5 splits" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.95999999999999996" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(scores) # Average them together" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.97333333333333338" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check CV score for K=5: mean accuracy over the 5 folds\n", "knn = KNeighborsClassifier(n_neighbors=5)\n", "scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')\n", "scores.mean()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[0.95999999999999996,\n", " 0.96666666666666679,\n", " 0.97333333333333338,\n", " 0.98000000000000009,\n", " 0.97333333333333338,\n", " 0.98000000000000009,\n", " 0.97333333333333338,\n", " 0.96666666666666679,\n", " 0.96666666666666679,\n", " 0.96666666666666679,\n", " 0.96666666666666679,\n", " 0.95999999999999996,\n", " 0.95999999999999996,\n", " 0.94666666666666666,\n", " 0.93333333333333324]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# search for an optimal value of K:\n", "# record the mean 5-fold CV accuracy for every odd K from 1 to 29\n", "k_range = range(1, 30, 2)\n", "scores = []\n", "for k in k_range:\n", "    knn = KNeighborsClassifier(n_neighbors=k)\n", "    scores.append(cross_val_score(knn, X, y, cv=5, scoring='accuracy').mean())\n", "scores" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" }, { "data": { 
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEACAYAAABfxaZOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAG/lJREFUeJzt3X+UHHWZ7/H3J4QoKCSgMSjh540CybJR1o3sJYFZ4MKA\nCBhZSQQWhY14BBPdkYX1HEkOu94FdkXDKiJsxOCK4Ye5iitqsidM+HFuIJAxhpgEDtGQCKIXV37I\nCvnx3D++1aEZZjLdMzVd3V2f1zlz0l1V3f0UxTz9nW/VU48iAjMzK4cRRQdgZmaN46RvZlYiTvpm\nZiXipG9mViJO+mZmJeKkb2ZWIjUlfUmdktZLekzSZX2sHyNpsaTVklZImli1bo6kNdnP7DyDNzOz\n+gyY9CWNAL4CnAxMAmZKOrzXZp8DeiJiMnA+cF322knAhcB7gXcDp0k6NL/wzcysHrWM9KcAj0fE\npojYCiwCzui1zURgGUBEbAAOljQWOAJ4MCJejojtwL3A9NyiNzOzutSS9PcHNlc935Itq7aaLJlL\nmgIcCIwHHgWmSdpH0p7AqcABQw3azMwGZ2RO73MVMF/SKmAN0ANsj4j1kq4GlgIvVpbn9JlmZlan\nWpL+r0gj94rx2bKdIuIF4ILKc0m/ADZm624Gbs6Wf4HX/tVA1Wt8EyAzszpFhOrZvpbpnZXABEkH\nSRoFzADuqt5A0mhJu2ePZwHLI+LF7PnY7N8DgQ8Ct+4i+Lb8mTt3buExeP+8f96/9vsZjAFH+hGx\nXdIlwBLSl8SCiFgn6aK0Om4knbBdKGkHsJZ0xU7FdyXtC2wFPhkRzw8qUjMzG7Ka5vQj4sfAYb2W\nfb3q8Yre66vWHTuUAM3MLD+uyG2Ajo6OokMYVt6/1ub9KxcNdl4ob5KiWWIxM2sFkohhOJFrZmZt\nwknfzKxEnPTNzErESd/MrESc9M3MSsRJ38ysRJz0zcxKxEnfzKxEnPTNzErESd/MrESc9M3MSsRJ\n38ysRJz0zcxKxEnfzKxEnPTNzErESd/MrESc9M3MSsRJ38ysRJz0zcxKxEnfzKxEnPTNzErESd/M\nrESc9M3MSsRJ38ysRJz0zcxKxEnfzKxEakr6kjolrZf0mKTL+lg/RtJiSaslrZA0sWrdZyQ9Kuln\nkr4taVSeO2BmZrUbMOlLGgF8BTgZmATMlHR4r80+B/RExGTgfOC67LXvAD4FHBURfwqMBGbkF76Z\nmdWjlpH+FODxiNgUEVuBRcAZvbaZCCwDiIgNwMGSxmbrdgPeJGkksCfwVC6Rm5lZ3WpJ+vsDm6ue\nb8mWVVsNTAeQNAU4EBgfEU8BXwSeBH4F/D4i/nOoQZuZ2eDkdSL3KmAfSauAi4EeYLukMaS/Cg4C\n3gG8WdJHcvrMUtu2Df7u72Dt2qIjqd2GDdDVBa+8UnQkZuU1soZtfkUauVeMz5btFBEvABdUnkva\nCGwEOoGNEfG7bPli4H8Ct/b1QfPmzdv5uKOjg46OjhrCK58I+MQn4MEHYdEieOABOOCAoqPataee\ngs5O2Gsv+M1vYOFCGOFrx8zq0t3dTXd395DeQxGx6w2k3YANwAnA08BDwMyIWFe1zWjgpYjYKmkW\ncExEfDSb6lkA/DnwMnAzsDIivtrH58RAsVjy+c/Dj38M99wDX/86fOMbcN99sO++RUfWt+eeg+OO\ngw9/GD79aTjxRJg6Fa65pujIzFqbJCJC9bxmwJF+RGyXdAmwhDQdtCAi1km6KK2OG4EjgIWSdgBr\ngQuz1z4k6U7SdM/W7N8b6wnQXuv66+G22+D+++HNb07TJU89BaefDkuXwh57FB3ha738Mpx5Zkry\nf//3IMEPfgDTpsHb3w6f+UzREZqVy4Aj/UbxSH9gixfDpz6VRvWHHvrq8h074Lzz4A9/gDvvhJG1\nTNo1wI4dMGNGmo5atAh22+3VdU8+CcccA//8
z2kbM6vfYEb6TvotYvly+Ku/gp/8BN7zntevf+UV\nOO00OOQQuOGGNKIuUgTMmQM/+1mainrjG1+/zZo1cMIJcOutacrHzOozmKTvU2ktYM2alPC/852+\nEz7AqFHw3e/Cww/DlVc2Nr6+XH11+qL6/vf7TvgARx6Z/jL5yEegp6ex8ZmVlZN+k9u0CU49Fa67\nLo2Kd2WvveDuu+Fb30oneIvyzW+mvzZ+9CMYPXrX2x57bNr2tNNg48aGhGdWak0y+2t9efbZdJlj\nV1ft897jxqXplGOPTY/PPHN4Y+zt7rvh8suhuxve8Y7aXjN9OjzzDJx8crr89G1vG9YQzUrNc/pN\n6qWX0jz3tGlpqqRejzwCp5ySTv5OnZp/fH158ME0Yv/BD+Doo+t//RVXpL8O7rknXZlkZrvmE7lt\nYts2+OAHYZ990lTJYIuYlixJV/UsWwaTJuUa4uts2JCuxV+wAN7//sG9RwTMmgWbN6cvjlG+H6vZ\nLvlEbhuoVNtu3ZoS6FCqVk86Ca69No34N28eePvBqlTb/tM/DT7hQ7ri6IYb4A1vgAsvTJd8mlm+\nnPSbzBVXwOrV6aqW3Xcf+vudc066dLKzE373u6G/X2/PPZdONM+aBR/72NDfb+TIdE3/E0+kcwNm\nli9P7zSR66+HL385VdvmfTKzqyvNuedZtfvyy+nLZNIk+Nd/zbc24Nln07mIj3/cVbtm/fGcfgvr\nr9o2L3lX7W7fDjNn9l1tm5dK1e4116TPMrPXctJvUQNV2+Ylr6rdCJg9OxWN9VdtmxdX7Zr1zydy\nW1At1bZ5yatq9+qr4d57d11tm5cjj4Q77nDVrllenPQLVE+1bV6GWrX7zW+m19VSbZuX446Dr33N\nVbtmeXBFbkEq1baf/Wzj7zI5blyaSpo2rb6q3Uq17fLltVfb5uVDH0rNV1y1azY0ntMvwFCrbfNS\nT9XuUKtt81LdQMZVu1Z2PpHbAvKqts1LLVW7eVTb5sVVu2av8oncJpdntW1eBqrazavaNi+Vqt1R\no1y1azYYTZB2yuOKK1JTkbyqbfPSX9Vu3tW2eRk5MrWMdNWuWf08vdMglWrbBx6AsWOLjqZvn/0s\nrFiRqnZHjEhfAn/yJ+nqoqI7cfXFVbtWdp7Tb1LDXW2bl+qq3VGjhrfaNi+u2rUyc9JvQvfeC2ed\nNfzVtnl55RX4wAfSvz/60fAXX+XBVbtWVk76TaaSjL7zncYVX+Vh+/Y0yh/q/XkaafnyV79cjzqq\n6GjMGsNX7zSRJ59sfLVtXnbbrbUSPqRLSiu9dp94ouhozJpXi/1qt4Znn02Vo0VU25ZZpWq3s9NV\nu2b98fROzpql2rbMXLVrZeE5/YI1W7VtWblq18rCSb9AlUSzZUtKNM1UfFVGlS/gMWNg4UJ/AVt7\n8oncAs2d25zVtmXlql2zvvlEbg6+9rVUxPTAA55DbiZ77pn+6po6Fd7+dlftmkGNI31JnZLWS3pM\n0mV9rB8jabGk1ZJWSJqYLX+XpB5Jq7J/n5M0O++dKNLixfCP/5hOHDbr7RXK7C1vScfm2mtTvYRZ\n2Q04py9pBPAYcALwFLASmBER66u2uQZ4ISL+QdJhwFcj4sQ+3mcL8L6IeN39HFtxTr9RvW1t6Fy1\na+1ouOb0pwCPR8SmiNgKLALO6LXNRGAZQERsAA6W1HvceyLwRF8JvxU1sretDV2l1+7MmbBqVdHR\nmBWnlqS/P1CdqLdky6qtBqYDSJoCHAiM77XN2UBb/IHdytW2ZeaqXbP8TuReBcyXtApYA/QA2ysr\nJe0OnA7s8jqKefPm7Xzc0dFBR0dHTuHlx9W2rc1Vu9bKuru76e7uHtJ71DKnfzQwLyI6s+eXAxER\n/dabSvoFcGREvJg9Px34ZOU9+nlN08/pu9q2fbhq19rBcM3prwQmSDpI0ihgBnBXrw8enY3mkTQL\nWF5J+JmZ
tPjUzrZtcPbZMGFCah1ore3KK2Hy5DTyf+WVoqMxa5yaKnIldQLzSV8SCyLiKkkXkUb8\nN2Z/DSwEdgBrgQsj4rnstXsCm4BDI+KFXXxG0470XW3bnly1a63Ot2EYJldckaYCli3zVEC7qUzZ\nTZ2aum+ZtZLBJH1X5A7A1bbtzVW7VjZO+rtQqba97z5X27azStXu1Kmw337utWvtzUm/H8uXwyc+\nkaptm7mZueXjoIPg7rtT3cXYsa7atfblU1d9cLVtOblq18rASb8XV9uWm6t2rd15eqeKq20NXLVr\n7c2XbGZcbWu9Vap277jDLRcr9tvP9QzNxNfpD9K2bTB9eirScW9bq4iAOXNSNzSDl19Og6I770yd\nyax4TvqD4Gpbs9q88gq8//3parYbbgDVlWpsOLhH7iDMnQurV7u3rdlARo1KtSsPP5zuXWStqdR/\npFWqbe+/39W2ZrXYa69Uz3DMMWl+/6KLio7I6lXapF9dbeurM8xqN25cOsE9bVp6fOaZRUdk9Shl\n0ne1rdnQTJiQzoGdeiq89a3pFhbWGko3p+9qW7N8vPe98O//nuoa1q4tOhqrVamS/qZNrrY1y9NJ\nJ8G118Ipp8DmzQNvb8UrzfTOs8+mCsuuLlfbmuXpnHPg179Ov1/33Qf77lt0RLYrpbhO39W2ZsOv\nqwsefBCWLoU99ig6mnJwcVYfKi3x9tnH1bZmw2nHDjjvPPjDH1y12yguzuolIl2ls3UrLFjghG82\nnEaMgJtvTkn/4ovT7581n7ZOg1dc4Wpbs0Zy1W7za9s/wK6/Hm67zdW2Zo3mqt3m1pZJf/Fi+MIX\nXG1rVhRX7Tavtkv6rrY1aw6u2m1ObTWnX6m2vfVWV9uaNQNX7Taftkn61dW2J55YdDRmVnHSSfDF\nL7pqt1m0xfSOq23Nmtu5575atXv//aluxorR8sVZL72U7qMzbRpcc80wBGZmuenqgocegiVLXLWb\nh9JV5Lra1qy1uGo3X8NWkSupU9J6SY9JuqyP9WMkLZa0WtIKSROr1o2WdIekdZLWSnpfPQH2x9W2\nZq3HVbvFGzBVShoBfAU4GZgEzJR0eK/NPgf0RMRk4Hzguqp184G7I+IIYDKwLo/AXW1r1ppctVus\nWsbHU4DHI2JTRGwFFgFn9NpmIrAMICI2AAdLGitpb2BaRNycrdsWEc8PNejrr0+9bX/4Q1fbmrWi\nStXuLbfAjTcWHU251DKjtj9QfaHVFtIXQbXVwHTgAUlTgAOB8cAO4P9Jupk0yn8YmBMR/z3YgL/7\nXVfbmrWDceNSEeWxx8If/wiHHJL/Z+y9d3p/1TXr3d7yOo1yFTBf0ipgDdADbAd2B44CLo6IhyV9\nGbgcmNvXm8ybN2/n446ODjo6Ol63zbJl8B//4Wpbs3YwYUL6fb7ySti+Pf/3X7kyFWsef3z+712E\n7u5uuru7h/QeA169I+loYF5EdGbPLwciIvptRyLpF8CRwJuA/xsRh2bLpwKXRcQH+njNsDVRMbNy\nuukm+N730lRwOxquq3dWAhMkHSRpFDADuKvXB4+WtHv2eBawPCJejIhngM2S3pVtegLw83oCNDMb\nrPPOg0cegZ876+w0YNKPiO3AJcASYC2wKCLWSbpI0sezzY4AHpW0jnSVz5yqt5gNfFvST0nz+v87\nzx0wM+vPG98In/xkat5uSUsXZ5mZDeS3v4V3vQvWr08nj9uJ2yWamfUydiycfTZ89atFR9IcPNI3\ns7a3YUO6P9cvfwl77ll0NPnxSN/MrA+HHQZ/8RepGKzsPNI3s1K49174m79Jc/vtcq8uj/TNzPox\nbRqMHp1aOJaZk76ZlYKU7uf/xS8WHUmxnPTNrDTOOiu1Vl25suhIiuOkb2alMXIkfPrT5R7t+0Su\nmZXK88+nO3o+8ggcfHDR0QyNT+SamQ1g773hggtg/vyiIymGR/pmVjqbN8
PkybBxI4wZU3Q0g+eR\nvplZDQ44AE45Jd16uWw80jezUlq1Cs44A554IvXtbUUe6ZuZ1eioo+Cd74Tbby86ksZy0jez0qoU\na5VpksFJ38xK65RTUlP2e+4pOpLGcdI3s9IaMQL+9m/LVazlE7lmVmp//GMq0lq2DCZOLDqa+vhE\nrplZncrWR9cjfTMrvVbto+uRvpnZIJSpj65H+mZmtGYfXY/0zcwGqdJHd+HCoiMZXh7pm5llWq2P\nrkf6ZmZDUIY+uk76ZmaZMvTRddI3M6tS6aP70ENFRzI8nPTNzKq0ex/dmk7kSuoEvkz6klgQEVf3\nWj8G+AbwP4D/Bi6IiJ9n634JPAfsALZGxJR+PsMncs2sKbRKH93BnMgdMOlLGgE8BpwAPAWsBGZE\nxPqqba4BXoiIf5B0GPDViDgxW7cR+LOI+K8BPsdJ38yaxqWXwrZt8KUvFR1J/4br6p0pwOMRsSki\ntgKLgDN6bTMRWAYQERuAgyWNrcRV4+eYmTWN2bPTNfu//33RkeSrlmS8P7C56vmWbFm11cB0AElT\ngAOB8dm6AJZKWilp1tDCNTNrjHbtozsyp/e5CpgvaRWwBugBtmfrjomIp7OR/1JJ6yLi/r7eZN68\neTsfd3R00NHRkVN4Zmb16+qC00+HOXOao49ud3c33d3dQ3qPWub0jwbmRURn9vxyIHqfzO31ml8A\nR0bEi72WzyXN/b/uJqae0zezZnT88XDBBXDuuUVH8nrDNae/Epgg6SBJo4AZwF29Pni0pN2zx7OA\n5RHxoqQ9Jb05W/4m4CTg0XoCNDMrUrv10R0w6UfEduASYAmwFlgUEeskXSTp49lmRwCPSloHnAzM\nyZaPA+6X1AOsAH4QEUvy3gkzs+HSbn10fcM1M7MB3HQTfO978MMfFh3Jaw3LdfqN4qRvZs2qWfvo\n+i6bZmbDoJ366Hqkb2ZWg2bso+uRvpnZMGmXProe6ZuZ1ajZ+uh6pG9mNozaoY+uR/pmZnVopj66\nHumbmQ2zVu+j66RvZlaHVu+j66RvZlanVu6j66RvZlanVu6j6xO5ZmaD0Ax9dH0i18ysQfbeO91n\nf/78oiOpj0f6ZmaDtHkzTJ4MGzfCmDGN/3yP9M3MGqgV++h6pG9mNgSrVqU+uhs3Nr6Prkf6ZmYN\ndtRR6e6bt99edCS1cdI3MxuiVuqj66RvZjZErdRH10nfzGyIRoxonVsz+ESumVkOiuij6xO5ZmYF\naZU+uh7pm5nlpNF9dD3SNzMrUCv00fVI38wsR43so+uRvplZwSp9dG+5pehI+uaRvplZzhrVR9cj\nfTOzJtDMfXRrSvqSOiWtl/SYpMv6WD9G0mJJqyWtkDSx1/oRklZJuiuvwM3MmlUz99EdMOlLGgF8\nBTgZmATMlHR4r80+B/RExGTgfOC6XuvnAD8ferhmZq3hrLPgySdh5cqiI3mtWkb6U4DHI2JTRGwF\nFgFn9NpmIrAMICI2AAdLGgsgaTxwKvBvuUVtZtbkRo6EOXOab7RfS9LfH9hc9XxLtqzaamA6gKQp\nwIHA+Gzdl4BLAZ+lNbNSufBCWLo0Xb7ZLPI6kXsVsI+kVcDFQA+wXdL7gWci4qeAsh8zs1Joxj66\nI2vY5lekkXvF+GzZThHxAnBB5bmkjcBGYAZwuqRTgT2AvSTdEhF/3dcHzZs3b+fjjo4OOjo6atoJ\nM7NmNXt26qM7d+7Q++h2d3fT3d09pPcY8Dp9SbsBG4ATgKeBh4CZEbGuapvRwEsRsVXSLOCYiPho\nr/c5DuiKiNP7+Rxfp29mbencc1Piv/TSfN93WK7Tj4jtwCXAEmAtsCgi1km6SNLHs82OAB6VtI50\nlc+c+kI3M2tfXV1w3XWwdWvRkbgi18ysIY4/Pp3YPeec/N7TFblmZk2qqwv+5V+K76PrpG9m1gDN\n0kfXSd/MrAGapY+u5/TNzBqk0kf3nn
vgiCOG/n6e0zcza2LN0EfXI30zswbKs4+uR/pmZk2u6D66\nHumbmTVYXn10PdI3M2sBRfbR9UjfzKwAefTR9UjfzKxFFNVH10nfzKwARfXRddI3MytIEX10nfTN\nzApSRB9dn8g1MyvQ88/DIYfAqlVw0EH1vdYncs3MWkyj++h6pG9mVrDNm1M7xY0b6+uj65G+mVkL\nOuAAOPVUuOmm4f8sj/TNzJpATw+cfnoa7e++e22v8UjfzKxFvec98M53wu23D+/nOOmbmTWJri64\n++7h/QxP75iZNYlKClSNEzaDmd4ZWW9QZmY2PGpN9kPh6R0zsxJx0jczKxEnfTOzEnHSNzMrESd9\nM7MSqSnpS+qUtF7SY5Iu62P9GEmLJa2WtELSxGz5GyQ9KKlH0hpJc/PeATMzq92ASV/SCOArwMnA\nJGCmpMN7bfY5oCciJgPnA9cBRMTLwF9GxHuAdwOnSJqSY/wtobu7u+gQhpX3r7V5/8qllpH+FODx\niNgUEVuBRcAZvbaZCCwDiIgNwMGSxmbPX8q2eQOpLqB0FVjt/j+d96+1ef/KpZakvz+wuer5lmxZ\ntdXAdIBsJH8gMD57PkJSD/BrYGlENLAxmJmZVcvrRO5VwD6SVgEXAz3AdoCI2JFN74wH3leZ7zcz\ns8Yb8N47ko4G5kVEZ/b8ciAi4updvOYXwJER8WKv5Z8H/hAR1/bxmtJN+5iZDdVw3HtnJTBB0kHA\n08AMYGb1BpJGAy9FxFZJs4DlEfGipLcCWyPiOUl7AP+L9FfBkAM3M7P6DZj0I2K7pEuAJaTpoAUR\nsU7SRWl13AgcASyUtANYC1yYvfzt2fIR2Wtvi4hhvnGomZn1p2lurWxmZsOv8IrcgQq/Wp2kX2ZF\naz2SHio6nqGStEDSM5J+VrVsH0lLJG2Q9JNsuq8l9bN/cyVtkbQq++ksMsbBkjRe0jJJa7NiydnZ\n8rY4fn3s36ey5e1y/Posdq33+BU60s+mfR4DTgCeIp0/mBER6wsLKmeSNgJ/FhH/VXQseZA0FXgR\nuCUi/jRbdjXwbERck31x7xMRlxcZ52D1s39zgRf6ugChlUjaD9gvIn4q6c3AI6Sam4/RBsdvF/t3\nNm1w/AAk7RkRL0naDXgAmA18iDqOX9Ej/VoKv1qdKP6/c24i4n6g9xfYGcDC7PFC4MyGBpWjfvYP\n0nFsaRHx64j4afb4RWAd6VLqtjh+/exfpaao5Y8f9FvsWtfxKzoZ1VL41eoCWCppZXZlUzt6W0Q8\nA+kXD3hbwfEMh0sk/VTSv7Xq9Ec1SQeTbo2yAhjXbsevav8ezBa1xfHrp9i1ruNXdNIvg2Mi4ijg\nVODibPqg3bXb1QHXA4dGxLtJv2wtPU2QTX3cCczJRsS9j1dLH78+9q9tjl+vYtcpkiZR5/ErOun/\ninTLhorx2bK2ERFPZ//+Fvg/pCmtdvOMpHGwc171NwXHk6uI+G28evLrJuDPi4xnKCSNJCXEb0XE\n97PFbXP8+tq/djp+FRHxPNANdFLn8Ss66e8s/JI0ilT4dVfBMeVG0p7ZqANJbwJOAh4tNqpciNfO\nkd4FfDR7fD7w/d4vaDGv2b/sF6liOq19DL8B/Dwi5lcta6fj97r9a5fjJ+mtlampqmLXddR5/Aq/\nTj+7fGo+rxZ+9Vmx24okHUIa3QfppMu3W33/JN0KdABvAZ4B5gLfA+4ADgA2AR+OiN8XFeNQ9LN/\nf0maH94B/BK4qDKH2kokHQPcC6wh/T8ZpNuiPwTcTosfv13s30doj+N3JOlEbXWx6xck7Usdx6/w\npG9mZo1T9PSOmZk1kJO+mVmJOOmbmZWIk76ZWYk46ZuZlYiTvplZiTjpm5mViJO+mVmJ/H8IRNaY\nteWYPgAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "# plot the K values (x-axis) versus the 5-fold CV score (y-axis)\n", "plt.figure()\n", "plt.plot(k_range, scores)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# automatic grid search for an optimal value of K\n", "# (GridSearchCV moved to sklearn.model_selection in scikit-learn 0.18 and\n", "# the old sklearn.grid_search module was removed in 0.20, so try the new\n", "# location first and fall back to the old one on legacy installs)\n", "try:\n", "    from sklearn.model_selection import GridSearchCV\n", "except ImportError:\n", "    from sklearn.grid_search import GridSearchCV\n", "knn = KNeighborsClassifier()\n", "k_range = range(1, 30, 2)\n", "# pass a concrete list so the grid is an ordinary sequence on Python 2 and 3\n", "param_grid = dict(n_neighbors=list(k_range))\n", "grid = GridSearchCV(knn, param_grid, cv=5, scoring='accuracy')\n", "grid.fit(X, y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# this will check every odd K from 1 to 29 (15 values in total),\n", "# and then do 5-fold cross validation on each one!\n", "# that's 15 * 5 = 75 fits and scorings\n", "# (plus one final refit of the best K on all the data)" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# check the results of the grid search\n", "# (cv_results_ replaced grid_scores_ in scikit-learn 0.18 and grid_scores_\n", "# was removed in 0.20; use whichever attribute this install provides)\n", "if hasattr(grid, 'cv_results_'):\n", "    grid_mean_scores = list(grid.cv_results_['mean_test_score'])\n", "else:\n", "    grid_mean_scores = [result[1] for result in grid.grid_scores_]\n", "grid_mean_scores" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# plot the results: K (x-axis) versus mean CV accuracy (y-axis)\n", "plt.figure()\n", "plt.plot(k_range, grid_mean_scores)\n", "plt.xlabel('K (n_neighbors)')\n", "plt.ylabel('mean 5-fold CV accuracy')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "grid.best_score_ # shows us the best score" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "grid.best_params_ # shows us the optimal parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "grid.best_estimator_ # this is the actual model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python [sfdat26-env]", "language": "python", "name": "Python [sfdat26-env]" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" } }, "nbformat": 4, "nbformat_minor": 0 }