{
 "metadata": {
  "name": ""
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Scikit-Learn is simple\n",
      "======================="
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Classification\n",
      "------------------"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.datasets import load_iris\n",
      "from sklearn.model_selection import train_test_split\n",
      "\n",
      "\n",
      "iris = load_iris()\n",
      "X, y = iris.data, iris.target\n",
      "X_train, X_test, y_train, y_test = train_test_split(X, y)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.svm import SVC\n",
      "clf = SVC()\n",
      "clf.fit(X_train, y_train)\n",
      "y_pred = clf.predict(X_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Transformations\n",
      "----------------"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.decomposition import PCA"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "pca = PCA(n_components=2)\n",
      "pca.fit(X)\n",
      "X_pca = pca.transform(X)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 4
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Tools\n",
      "======"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Cross-validation scoring\n",
      "--------------------------"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy as np\n",
      "np.set_printoptions(precision=2)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 5
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.model_selection import cross_val_score, StratifiedKFold\n",
      "scores = cross_val_score(SVC(), X_train, y_train, cv=5)\n",
      "print(scores)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[ 0.92 1. 0.96 1. 1. ]\n"
       ]
      }
     ],
     "prompt_number": 6
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.model_selection import ShuffleSplit\n",
      "# 10 random splits, each holding out 10% of the training samples.\n",
      "cv_ss = ShuffleSplit(n_splits=10, test_size=0.1)\n",
      "scores_shuffle_split = cross_val_score(SVC(), X_train, y_train, cv=cv_ss)\n",
      "print(scores_shuffle_split)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "[ 1. 1. 1. 1. 1. 0.83 1. 0.92 1. 0.92]\n"
       ]
      }
     ],
     "prompt_number": 7
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.model_selection import LeaveOneGroupOut\n",
      "# Assign every training sample to one of three groups (0, 1, 2);\n",
      "# each split then holds out one whole group as the test fold.\n",
      "labels = np.arange(len(X_train)) % 3\n",
      "cv_label = LeaveOneGroupOut()\n",
      "scores_pout = cross_val_score(SVC(), X_train, y_train, groups=labels, cv=cv_label)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 8
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Cross-validated grid-searches\n",
      "------------------------------"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy as np\n",
      "from sklearn.model_selection import GridSearchCV\n",
      "# Search C and gamma over six powers of ten each: 1e-3 ... 1e2.\n",
      "param_grid = {'C': 10. ** np.arange(-3, 3), 'gamma': 10. ** np.arange(-3, 3)}\n",
      "grid = GridSearchCV(SVC(), param_grid=param_grid)\n",
      "grid.fit(X_train, y_train)\n",
      "print(grid.best_params_)\n",
      "print(grid.score(X_test, y_test))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "{'C': 100.0, 'gamma': 0.01}\n",
        "1.0\n"
       ]
      }
     ],
     "prompt_number": 9
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Pipelining\n",
      "----------"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.pipeline import make_pipeline\n",
      "from sklearn.preprocessing import StandardScaler\n",
      "\n",
      "pipe = make_pipeline(StandardScaler(), SVC())\n",
      "pipe.fit(X_train, y_train)\n",
      "pipe.predict(X_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 10,
       "text": [
        "array([0, 0, 1, 2, 0, 2, 0, 1, 0, 2, 2, 1, 2, 2, 0, 2, 1, 2, 1, 1, 1, 1, 0,\n",
        " 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1])"
       ]
      }
     ],
     "prompt_number": 10
    }
   ],
   "metadata": {}
  }
 ]
}