{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Support Vector Machines" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example we are going to use sklearn's data loading utilities instead of our trustworthy Pandas. Pandas is a great tool, but it can become a crutch. I've spent several coding hours just getting datasets to comply with the Pandas format." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.datasets import load_digits\n", "import pylab as plt\n", "%matplotlib inline " ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "digits = load_digits() #This is a copy of the test set of the UCI handwritten digits dataset (8x8 images), not MNIST" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'images': array([[[ 0., 0., 5., ..., 1., 0., 0.],\n", " [ 0., 0., 13., ..., 15., 5., 0.],\n", " [ 0., 3., 15., ..., 11., 8., 0.],\n", " ..., \n", " [ 0., 4., 11., ..., 12., 7., 0.],\n", " [ 0., 2., 14., ..., 12., 0., 0.],\n", " [ 0., 0., 6., ..., 0., 0., 0.]],\n", "\n", " [[ 0., 0., 0., ..., 5., 0., 0.],\n", " [ 0., 0., 0., ..., 9., 0., 0.],\n", " [ 0., 0., 3., ..., 6., 0., 0.],\n", " ..., \n", " [ 0., 0., 1., ..., 6., 0., 0.],\n", " [ 0., 0., 1., ..., 6., 0., 0.],\n", " [ 0., 0., 0., ..., 10., 0., 0.]],\n", "\n", " [[ 0., 0., 0., ..., 12., 0., 0.],\n", " [ 0., 0., 3., ..., 14., 0., 0.],\n", " [ 0., 0., 8., ..., 16., 0., 0.],\n", " ..., \n", " [ 0., 9., 16., ..., 0., 0., 0.],\n", " [ 0., 3., 13., ..., 11., 5., 0.],\n", " [ 0., 0., 0., ..., 16., 9., 0.]],\n", "\n", " ..., \n", " [[ 0., 0., 1., ..., 1., 0., 0.],\n", " [ 0., 0., 13., ..., 2., 1., 0.],\n", " [ 0., 0., 16., ..., 16., 5., 0.],\n", " ..., \n", " [ 0., 0., 16., ..., 15., 0., 0.],\n", " [ 0., 0., 15., ..., 16., 0., 0.],\n", " [ 0., 0., 2., ..., 6., 0., 0.]],\n", "\n", " [[ 0., 0., 2., ..., 0., 0., 
0.],\n", " [ 0., 0., 14., ..., 15., 1., 0.],\n", " [ 0., 4., 16., ..., 16., 7., 0.],\n", " ..., \n", " [ 0., 0., 0., ..., 16., 2., 0.],\n", " [ 0., 0., 4., ..., 16., 2., 0.],\n", " [ 0., 0., 5., ..., 12., 0., 0.]],\n", "\n", " [[ 0., 0., 10., ..., 1., 0., 0.],\n", " [ 0., 2., 16., ..., 1., 0., 0.],\n", " [ 0., 0., 15., ..., 15., 0., 0.],\n", " ..., \n", " [ 0., 4., 16., ..., 16., 6., 0.],\n", " [ 0., 8., 16., ..., 16., 8., 0.],\n", " [ 0., 1., 8., ..., 12., 1., 0.]]]), 'data': array([[ 0., 0., 5., ..., 0., 0., 0.],\n", " [ 0., 0., 0., ..., 10., 0., 0.],\n", " [ 0., 0., 0., ..., 16., 9., 0.],\n", " ..., \n", " [ 0., 0., 1., ..., 6., 0., 0.],\n", " [ 0., 0., 2., ..., 12., 0., 0.],\n", " [ 0., 0., 10., ..., 12., 1., 0.]]), 'target_names': array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 'DESCR': \"Optical Recognition of Handwritten Digits Data Set\\n===================================================\\n\\nNotes\\n-----\\nData Set Characteristics:\\n :Number of Instances: 5620\\n :Number of Attributes: 64\\n :Attribute Information: 8x8 image of integer pixels in the range 0..16.\\n :Missing Attribute Values: None\\n :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)\\n :Date: July; 1998\\n\\nThis is a copy of the test set of the UCI ML hand-written digits datasets\\nhttp://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\\n\\nThe data set contains images of hand-written digits: 10 classes where\\neach class refers to a digit.\\n\\nPreprocessing programs made available by NIST were used to extract\\nnormalized bitmaps of handwritten digits from a preprinted form. From a\\ntotal of 43 people, 30 contributed to the training set and different 13\\nto the test set. 32x32 bitmaps are divided into nonoverlapping blocks of\\n4x4 and the number of on pixels are counted in each block. This generates\\nan input matrix of 8x8 where each element is an integer in the range\\n0..16. 
This reduces dimensionality and gives invariance to small\\ndistortions.\\n\\nFor info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.\\nT. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.\\nL. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,\\n1994.\\n\\nReferences\\n----------\\n - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their\\n Applications to Handwritten Digit Recognition, MSc Thesis, Institute of\\n Graduate Studies in Science and Engineering, Bogazici University.\\n - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.\\n - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.\\n Linear dimensionalityreduction using relevance weighted LDA. School of\\n Electrical and Electronic Engineering Nanyang Technological University.\\n 2005.\\n - Claudio Gentile. A New Approximate Maximal Margin Classification\\n Algorithm. NIPS. 2000.\\n\", 'target': array([0, 1, 2, ..., 8, 9, 8])}\n" ] } ], "source": [ "print digits" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 0. 5. ..., 0. 0. 0.]\n", " [ 0. 0. 0. ..., 10. 0. 0.]\n", " [ 0. 0. 0. ..., 16. 9. 0.]\n", " ..., \n", " [ 0. 0. 1. ..., 6. 0. 0.]\n", " [ 0. 0. 2. ..., 12. 0. 0.]\n", " [ 0. 0. 10. ..., 12. 1. 0.]]\n" ] } ], "source": [ "print digits.data" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1797, 64)\n" ] } ], "source": [ "print digits.data.shape #each 8x8 image is flattened into 64 features, so the shape is (1797, 64), not (1797, 8, 8)\n", "#As long as every sample uses the same feature ordering, the order of the columns does not matter to the SVM\n", "#We refer to this as column-switching invariance." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPcAAAD7CAYAAAC2TgIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAC15JREFUeJzt3V2MVdUZxvHnQSgBiYRUtEbqBzZowg0CFRP8rB/1I9Ir\nUGtiyoV40QbSJgbCTe8a7wxJvTEqVYo2GVKixqYBJNpoA5UBFAesrdSKRYhGojGYRsvbi7MxpCWZ\nPZy91sy8/n/JyWwmM+d9z8x5Zq2zz2YtR4QA5DNhtBsAUAbhBpIi3EBShBtIinADSRFuIKlRC7ft\n22y/bfsd26sL13rC9lHbb5asc0q9Wba32x6yvc/2ysL1JtveaXtPU/NXJes1NSfY3m37+dK1mnrv\n2X6jeYx/KVxruu0B2wean+eigrXmNI9pd/Px086eLxFR/abeH5W/S7pY0iRJeyVdUbDeNZLmSXqz\n0uP7jqR5zfE0SX8t+fiaOlObj2dJ2iFpceF6P5f0W0nPV/qZHpQ0o1Kt30ha3hxPlHROpboTJB2W\n9N0u7m+0Ru6rJP0tIv4ZEV9K+p2kH5UqFhGvSjpW6v5PU+9IROxtjj+XdEDShYVrHm8OJ6v3JCn2\neG3PknSHpMdL1ThdWVWYado+R9K1EbFekiLiq4j4rHTdxs2S3o2IQ13c2WiF+0JJpz6AD1T4yT9a\nbF+i3qxhZ+E6E2zvkXRE0ssRsb9guUckPSSp5uWNIWmr7ddtP1CwzqWSPra9vpkqP2Z7SsF6p7pb\n0rNd3Rkn1AqyPU3SJkmrmhG8mIg4ERFXSpol6Trb15eoY/tOSUebmYmbWw2LI2K+ejOGn9q+plCd\niZLmS3q0qXdc0ppCtb5me5KkJZIGurrP0Qr3vyRddMq/ZzWfS8P2RPWCvSEinqtVt5lCvihpYaES\niyUtsX1QvVHmRttPF6r1tYj4sPn4kaTN6r20K+EDSYciYlfz703qhb202yUNNo+vE6MV7tclfc/2\nxba/JekeSaXPutYcZSTpSUn7I2Jd6UK2z7U9vTmeIukW9U5Sdi4i1kbERRExW73f2/aIuL9ErZNs\nT21mQbJ9tqRbJb1VolZEHJV0yPac5lM3SSr5Eueke9XhlFzqTUGqi4j/2P6ZpC3q/YF5IiIOlKpn\n+xlJN0j6tu33Jf3y5AmTQvUWS7pP0r7mdXBIWhsRfyxU8gJJT9k+edJpQ0S8VKjWaDhf0mbbod5z\ndmNEbClYb6Wkjc1U+aCk5QVryfZU9U6mrej0fptT8ACS4YQakBThBpIi3EBShBtIinADSXX2Vljz\nNgWAURAR/3cNx6i8zz0eLV26dMTfMzQ0pLlz555RvYcffnjE37Nu3TqtWrXqjOpt27ZtxN/zwgsv\n6K677jqjemvWjPyKzi+++EJTppzZZd7HjlX7f0NjBtNyICnCDSRFuAuaOXNm1XqLFhVbMOS05syZ\nM/wXdWjiRF5FjgThLui8886rWu/qq6+uWu/yyy+vWm/SpElV6413hBtIinADSRFuIKlW4a65DDGA\nbgwbbtsTJP1a0g8lzZV0r+0rSjcGoD9tRu6qyxAD6EabcH9jliEGMuGEGpBUm3CnX4YYyKhNuEdj\nGWIAfRr2Yt3ayxAD6EarK/Gb9bbrXkgMoC+cUAOSI
txAUoQbSIpwA0kRbiApwg0kRbiBpAg3kBTh\nBpJirdiWzmQHkH7Mnj27ar0ZM2ZUrffJJ59Urbds2bKq9QYGBqrWOx1GbiApwg0kRbiBpAg3kBTh\nBpIi3EBShBtIinADSRFuIKk22wk9Yfuo7TdrNASgG21G7vXq7RMGYBwZNtwR8aqkYxV6AdAhXnMD\nSRFuICnCDSTVNtxubgDGiTZvhT0j6c+S5th+3/by8m0B6FebjQB/XKMRAN3iNTeQFOEGkiLcQFKE\nG0iKcANJEW4gKcINJEW4gaQIN5DUuN0rbMGCBVXr1d6767LLLqta7+DBg1Xrbd26tWq92s8X9goD\nUAzhBpIi3EBShBtIinADSRFuICnCDSRFuIGkCDeQVJsFEmfZ3m57yPY+2ytrNAagP20uP/1K0i8i\nYq/taZIGbW+JiLcL9wagD232CjsSEXub488lHZB0YenGAPRnRK+5bV8iaZ6knSWaAdCd1uFupuSb\nJK1qRnAAY1ircNueqF6wN0TEc2VbAtCFtiP3k5L2R8S6ks0A6E6bt8IWS7pP0g9s77G92/Zt5VsD\n0I82e4W9JumsCr0A6BBXqAFJEW4gKcINJEW4gaQIN5AU4QaSItxAUoQbSIpwA0mN273CZsyYUbXe\n4OBg1Xq19+6qrfbP85uIkRtIinADSRFuICnCDSRFuIGkCDeQFOEGkiLcQFKEG0hq2CvUbE+W9CdJ\n32puz0XE2tKNAehPmwUS/237xog4bvssSa/ZXtwsnAhgjGo1LY+I483h5OZ7jhXrCEAn2u44MsH2\nHklHJL0cEfvLtgWgX21H7hMRcaWkWZKus3192bYA9GtEZ8sj4jNJL0paWKYdAF1ps53QubanN8dT\nJN0iaW/pxgD0p81iDRdIesq21ftjsCEiXirbFoB+tXkrbJ+k+RV6AdAhrlADkiLcQFKEG0iKcANJ\nEW4gKcINJEW4gaQIN5AU4QaSYq+wlrZt21a1Xna1f3/Hjn3zliBg5AaSItxAUoQbSIpwA0kRbiAp\nwg0kRbiBpAg3kBThBpJqHe5mY4Ldtp8v2RCAboxk5F4liZ1GgHGi7XZCsyTdIenxsu0A6ErbkfsR\nSQ9JioK9AOhQmx1H7pR0NCL2SnJzAzDGtRm5F0taYvugpGcl3Wj76bJtAejXsOGOiLURcVFEzJZ0\nj6TtEXF/+dYA9IP3uYGkRrQSS0S8IumVQr0A6BAjN5AU4QaSItxAUoQbSIpwA0kRbiApwg0kRbiB\npAg3kNS43Sus9t5PCxYsqFqvttp7d9X+eQ4MDFStNxYwcgNJEW4gKcINJEW4gaQIN5AU4QaSItxA\nUoQbSIpwA0m1ukLN9nuSPpV0QtKXEXFVyaYA9K/t5acnJN0QEXWv+QRwxtpOyz2CrwUwBrQNbEja\navt12w+UbAhAN9pOyxdHxIe2Z6oX8gMR8WrJxgD0p9XIHREfNh8/krRZEifUgDGuzS6fU21Pa47P\nlnSrpLdKNwagP22m5edL2mw7mq/fGBFbyrYFoF/Dhjsi/iFpXoVeAHSIt7eApAg3kBThBpIi3EBS\nhBtIinADSRFuICnCDSRFuIGkHBHd3FHv8tRqZs+eXbOcdu3aVbXegw8+WLXe0qVLq9ar/ftbuHBh\n1Xq1RYT/93OM3EBShBtIinADSRFuICnCDSRFuIGkCDeQFOEGkiLcQFKtwm17uu0B2wdsD9leVLox\nAP1puynBOkl/iIiltidKmlqwJwAdGDbcts+RdG1E/ESSIuIrSZ8V7gtAn9pMyy+V9LHt9bZ3237M\n9pTSjQHoT5twT5Q0X9KjETFf0nFJa4p2BaBvbcL9gaRDEXHy/zxuUi/sAMawYcMdEUclHbI9p/nU\nTZL2F+0KQN/ani1fKWmj7UmSDkpaXq4lAF1oFe6IeEPS9wv3AqBDXKEGJEW4gaQIN5AU4QaSItxA\nUoQbSIpwA0kRb
iApwg0kNW73CqttxYoVVeutXr26ar3BwcGq9ZYtW1a1XnbsFQZ8gxBuICnCDSRF\nuIGkCDeQFOEGkiLcQFKEG0hq2HDbnmN7T7Nm+R7bn9peWaM5AGdu2DXUIuIdSVdKku0J6i11vLlw\nXwD6NNJp+c2S3o2IQyWaAdCdkYb7bknPlmgEQLdah7tZs3yJpIFy7QDoykhG7tslDUbER6WaAdCd\nkYT7XjElB8aNVuG2PVW9k2m/L9sOgK603U7ouKSZhXsB0CGuUAOSItxAUoQbSIpwA0kRbiApwg0k\nRbgLOnz4cNV6O3bsqFpvaGioaj2MDOEuqHa4d+7cWbUe4R7bCDeQFOEGkmKvMCCB0+0V1lm4AYwt\nTMuBpAg3kBThBpIi3EBShBtI6r86MxYZKbAIsgAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.gray() #gray() switches the default colormap to grayscale\n", "plt.matshow(digits.images[0]) \n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 1 2 ..., 8 9 8]\n" ] } ], "source": [ "print digits.target #target has the real labels for each image" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.svm import SVC\n", "import numpy as np\n", "from sklearn.metrics import accuracy_score\n", "clf = SVC()\n", "y = digits.target #The labels are our target\n", "X = digits.data #The data are our features (1797, 64)\n", "clf.fit(X, y) #This is as vanilla as it gets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- C: Penalty parameter of the error term; it controls the trade-off between a wide margin and misclassified training points\n", "- degree: Degree of the polynomial kernel (only used when the kernel is 'poly')\n", "- gamma: Kernel coefficient; for the RBF kernel it is the gamma in `exp(-gamma * ||x - x'||^2)`\n", "\n", "Note that the default kernel is RBF and gamma is 'auto'; 'auto' sets gamma to 1/n_features, in our 
case 1/64 = 0.0156" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 1 2 ..., 8 9 8]\n" ] } ], "source": [ "print clf.predict(X)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_predict = clf.predict(X)\n", "np.sum(y_predict != y) #Counts the mismatches; zero means every prediction matches its label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is perfect, but we tested on the same data we trained on. What happens if we train on half of the data and test on the other half?" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(898, 64) (899, 64) (898,) (899,)\n" ] } ], "source": [ "from sklearn.cross_validation import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=30)\n", "#random states are important for repeatability\n", "print X_train.shape, X_test.shape, y_train.shape, y_test.shape" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'rbf'" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.22358175750834261" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "22% accuracy means it only gets the right result about 22 out of 100 times! With 10 classes, random guessing would score about 10%, so this is only marginally better." 
] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true }, "outputs": [], "source": [ "clf.kernel = 'linear'" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.98109010011123465" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.090100111234705224" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'sigmoid'\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.090100111234705224" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 1\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)\n" ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.090100111234705224" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 0.5\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.090100111234705224" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 0.1\n", "clf.fit(X_train, y_train)\n", "y_predict = 
clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.63737486095661844" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 0.01\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.9899888765294772" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 0.001\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.967741935483871" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 0.0001\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.82313681868743049" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.kernel = 'rbf'\n", "clf.gamma = 0.00001\n", "clf.fit(X_train, y_train)\n", "y_predict = clf.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": 90, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 1.00000000e-03 3.16227766e-02 1.00000000e+00 3.16227766e+01\n", " 1.00000000e+03]\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=None, error_score='raise',\n", " estimator=SVC(C=1.0, cache_size=200, 
class_weight=None, coef0=0.0,\n", " decision_function_shape=None, degree=3, gamma=1e-05, kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False),\n", " fit_params={}, iid=True, n_jobs=5,\n", " param_grid={'kernel': ('linear', 'rbf'), 'gamma': array([ 1.00000e-03, 3.16228e-02, 1.00000e+00, 3.16228e+01,\n", " 1.00000e+03])},\n", " pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn import grid_search\n", "gamma_list = np.logspace(-3, 3, 5)\n", "print gamma_list\n", "parameters = {'kernel':('linear', 'rbf'), 'gamma':gamma_list}\n", "clgs = grid_search.GridSearchCV(clf, parameters, n_jobs = 5)\n", "clgs.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 91, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False)\n" ] }, { "data": { "text/plain": [ "0.9899888765294772" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print clgs.best_estimator_\n", "y_predict = clgs.predict(X_test)\n", "accuracy_score(y_test, y_predict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" } }, "nbformat": 4, "nbformat_minor": 0 }