{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# GridSearch\n", "> [公式ドキュメント](http://scikit-learn.org/stable/modules/grid_search.html#grid-search)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "下ごしらえ" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "%matplotlib inline\n", "import pandas as pd\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from sklearn import cross_validation as cv\n", "from sklearn.utils import shuffle" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "from sklearn.svm import SVC\n", "clf=SVC(probability=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "from sklearn.datasets import load_iris\n", "data=load_iris().data\n", "target=load_iris().target" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# 今回のお題" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "from sklearn import grid_search as grid" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### モデルのインスタンスはget_params()で設定できるパラメータを表示できる" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'C': 1.0,\n", " 'cache_size': 200,\n", " 'class_weight': None,\n", " 'coef0': 0.0,\n", " 'decision_function_shape': None,\n", " 'degree': 3,\n", " 'gamma': 'auto',\n", " 'kernel': 'rbf',\n", " 'max_iter': -1,\n", " 'probability': True,\n", " 'random_state': None,\n", " 'shrinking': True,\n", " 'tol': 0.001,\n", " 'verbose': False}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.get_params()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### gridで設定したいパラメータを辞書形式で定義する" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "scrolled": true }, "outputs": [], "source": [ "parameters = {\n", " 'kernel':('linear', 'rbf'),\n", " 'C': np.linspace(1,10,5),\n", " 'gamma' : np.append(\n", " np.logspace(-4,1,11).astype('object'),\n", " 'auto'\n", " )\n", "}" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### パラメータグリッドをまとめたモデル群のインスタンスを作成\n", "\n", "- クロスバリデーションまでまとめて実行可能\n", "\n", "- デフォルトはcv=None" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [], "source": [ "grid_clf=grid.GridSearchCV(clf,parameters,n_jobs=-1,cv=5)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "中身" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'C': array([ 1. , 3.25, 5.5 , 7.75, 10. ]),\n", " 'gamma': array([0.0001, 0.00031622776601683794, 0.001, 0.0031622776601683794, 0.01,\n", " 0.03162277660168379, 0.1, 0.31622776601683794, 1.0,\n", " 3.1622776601683795, 10.0, 'auto'], dtype=object),\n", " 'kernel': ('linear', 'rbf')}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_clf.param_grid" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "フィッティング" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n", " max_iter=-1, probability=True, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False),\n", " fit_params={}, iid=True, n_jobs=-1,\n", " param_grid={'kernel': ('linear', 'rbf'), 'C': array([ 1. , 3.25, 5.5 , 7.75, 10. ]), 'gamma': array([0.0001, 0.00031622776601683794, 0.001, 0.0031622776601683794, 0.01,\n", " 0.03162277660168379, 0.1, 0.31622776601683794, 1.0,\n", " 3.1622776601683795, 10.0, 'auto'], dtype=object)},\n", " pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_shuffle,target_shuffle=shuffle(data,target)\n", "grid_clf.fit(data_shuffle,target_shuffle)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "探索結果は`grid_scores_`に入っている" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
parametersmean_validation_scorecv_validation_scores
0{'kernel': 'linear', 'C': 1.0, 'gamma': 0.0001}0.973333[0.966666666667, 0.966666666667, 0.93333333333...
1{'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0001}0.920000[0.9, 0.9, 0.9, 0.933333333333, 0.966666666667]
2{'kernel': 'linear', 'C': 1.0, 'gamma': 0.0003...0.973333[0.966666666667, 0.966666666667, 0.93333333333...
3{'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0003162...0.920000[0.9, 0.9, 0.9, 0.933333333333, 0.966666666667]
4{'kernel': 'linear', 'C': 1.0, 'gamma': 0.001}0.973333[0.966666666667, 0.966666666667, 0.93333333333...
\n", "
" ], "text/plain": [ " parameters mean_validation_score \\\n", "0 {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0001} 0.973333 \n", "1 {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0001} 0.920000 \n", "2 {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0003... 0.973333 \n", "3 {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0003162... 0.920000 \n", "4 {'kernel': 'linear', 'C': 1.0, 'gamma': 0.001} 0.973333 \n", "\n", " cv_validation_scores \n", "0 [0.966666666667, 0.966666666667, 0.93333333333... \n", "1 [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] \n", "2 [0.966666666667, 0.966666666667, 0.93333333333... \n", "3 [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] \n", "4 [0.966666666667, 0.966666666667, 0.93333333333... " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(grid_clf.grid_scores_).head()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "ベストのモデルは別名が付いている" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "scrolled": true, "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'kernel': 'rbf', 'C': 7.75, 'gamma': 0.03162277660168379}\n", "0.986666666667\n" ] } ], "source": [ "print(grid_clf.best_params_)\n", "print(grid_clf.best_score_)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# ランダムフォレストの場合\n", "> iris(4次元)でRF使わないですが..." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "scrolled": true, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'oob_score': False, 'max_features': 'auto', 'criterion': 'gini', 'n_jobs': 1, 'max_depth': None, 'min_samples_split': 2, 'verbose': 0, 'min_samples_leaf': 1, 'bootstrap': True, 'max_leaf_nodes': None, 'class_weight': None, 'warm_start': False, 'min_weight_fraction_leaf': 0.0, 'random_state': None, 'n_estimators': 10}\n" ] } ], "source": [ "from sklearn.ensemble import RandomForestClassifier as RFC\n", "clf_rf=RFC()\n", "print(clf_rf.get_params())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "scrolled": true, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n", " oob_score=False, random_state=None, verbose=0,\n", " warm_start=False),\n", " fit_params={}, iid=True, n_jobs=-1,\n", " param_grid={'n_estimators': [1, 2, 4, 8, 16, 32, 64, 128, 256]},\n", " pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "parameters_rf={'n_estimators':[1*2**i for i in range(9)]}\n", "grid_clf_rf=grid.GridSearchCV(clf_rf,parameters_rf,n_jobs=-1,cv=5)\n", "grid_clf_rf.fit(data_shuffle,target_shuffle)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
parametersmean_validation_scorecv_validation_scores
0{'n_estimators': 1}0.920000[0.833333333333, 0.866666666667, 0.93333333333...
1{'n_estimators': 2}0.906667[0.9, 0.833333333333, 0.833333333333, 0.966666...
2{'n_estimators': 4}0.933333[0.9, 0.866666666667, 0.966666666667, 0.966666...
3{'n_estimators': 8}0.953333[0.933333333333, 0.866666666667, 0.96666666666...
4{'n_estimators': 16}0.953333[0.933333333333, 0.866666666667, 0.96666666666...
5{'n_estimators': 32}0.953333[0.933333333333, 0.866666666667, 0.96666666666...
6{'n_estimators': 64}0.953333[0.933333333333, 0.866666666667, 0.96666666666...
7{'n_estimators': 128}0.953333[0.933333333333, 0.866666666667, 0.96666666666...
8{'n_estimators': 256}0.953333[0.933333333333, 0.866666666667, 0.96666666666...
\n", "
" ], "text/plain": [ " parameters mean_validation_score \\\n", "0 {'n_estimators': 1} 0.920000 \n", "1 {'n_estimators': 2} 0.906667 \n", "2 {'n_estimators': 4} 0.933333 \n", "3 {'n_estimators': 8} 0.953333 \n", "4 {'n_estimators': 16} 0.953333 \n", "5 {'n_estimators': 32} 0.953333 \n", "6 {'n_estimators': 64} 0.953333 \n", "7 {'n_estimators': 128} 0.953333 \n", "8 {'n_estimators': 256} 0.953333 \n", "\n", " cv_validation_scores \n", "0 [0.833333333333, 0.866666666667, 0.93333333333... \n", "1 [0.9, 0.833333333333, 0.833333333333, 0.966666... \n", "2 [0.9, 0.866666666667, 0.966666666667, 0.966666... \n", "3 [0.933333333333, 0.866666666667, 0.96666666666... \n", "4 [0.933333333333, 0.866666666667, 0.96666666666... \n", "5 [0.933333333333, 0.866666666667, 0.96666666666... \n", "6 [0.933333333333, 0.866666666667, 0.96666666666... \n", "7 [0.933333333333, 0.866666666667, 0.96666666666... \n", "8 [0.933333333333, 0.866666666667, 0.96666666666... " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(grid_clf_rf.grid_scores_)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# 終わり\n", "@y__sama" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 0 }