{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# GridSearch\n",
"> [公式ドキュメント](http://scikit-learn.org/stable/modules/grid_search.html#grid-search)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"下ごしらえ"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import pandas as pd\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"from sklearn import cross_validation as cv\n",
"from sklearn.utils import shuffle"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"clf=SVC(probability=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"data=load_iris().data\n",
"target=load_iris().target"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# 今回のお題"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn import grid_search as grid"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### モデルのインスタンスはget_params()で設定できるパラメータを表示できる"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'C': 1.0,\n",
" 'cache_size': 200,\n",
" 'class_weight': None,\n",
" 'coef0': 0.0,\n",
" 'decision_function_shape': None,\n",
" 'degree': 3,\n",
" 'gamma': 'auto',\n",
" 'kernel': 'rbf',\n",
" 'max_iter': -1,\n",
" 'probability': True,\n",
" 'random_state': None,\n",
" 'shrinking': True,\n",
" 'tol': 0.001,\n",
" 'verbose': False}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.get_params()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### gridで設定したいパラメータを辞書形式で定義する"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"parameters = {\n",
" 'kernel':('linear', 'rbf'),\n",
" 'C': np.linspace(1,10,5),\n",
" 'gamma' : np.append(\n",
" np.logspace(-4,1,11).astype('object'),\n",
" 'auto'\n",
" )\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"#### パラメータグリッドをまとめたモデル群のインスタンスを作成\n",
"\n",
"- クロスバリデーションまでまとめて実行可能\n",
"\n",
"- デフォルトはcv=None"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"grid_clf=grid.GridSearchCV(clf,parameters,n_jobs=-1,cv=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"中身"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'C': array([ 1. , 3.25, 5.5 , 7.75, 10. ]),\n",
" 'gamma': array([0.0001, 0.00031622776601683794, 0.001, 0.0031622776601683794, 0.01,\n",
" 0.03162277660168379, 0.1, 0.31622776601683794, 1.0,\n",
" 3.1622776601683795, 10.0, 'auto'], dtype=object),\n",
" 'kernel': ('linear', 'rbf')}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid_clf.param_grid"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"フィッティング"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5, error_score='raise',\n",
" estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
" decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n",
" max_iter=-1, probability=True, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False),\n",
" fit_params={}, iid=True, n_jobs=-1,\n",
" param_grid={'kernel': ('linear', 'rbf'), 'C': array([ 1. , 3.25, 5.5 , 7.75, 10. ]), 'gamma': array([0.0001, 0.00031622776601683794, 0.001, 0.0031622776601683794, 0.01,\n",
" 0.03162277660168379, 0.1, 0.31622776601683794, 1.0,\n",
" 3.1622776601683795, 10.0, 'auto'], dtype=object)},\n",
" pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_shuffle,target_shuffle=shuffle(data,target)\n",
"grid_clf.fit(data_shuffle,target_shuffle)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"探索結果は`grid_scores_`に入っている"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"
\n",
" \n",
" \n",
" | \n",
" parameters | \n",
" mean_validation_score | \n",
" cv_validation_scores | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0001} | \n",
" 0.973333 | \n",
" [0.966666666667, 0.966666666667, 0.93333333333... | \n",
"
\n",
" \n",
" 1 | \n",
" {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0001} | \n",
" 0.920000 | \n",
" [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] | \n",
"
\n",
" \n",
" 2 | \n",
" {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0003... | \n",
" 0.973333 | \n",
" [0.966666666667, 0.966666666667, 0.93333333333... | \n",
"
\n",
" \n",
" 3 | \n",
" {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0003162... | \n",
" 0.920000 | \n",
" [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] | \n",
"
\n",
" \n",
" 4 | \n",
" {'kernel': 'linear', 'C': 1.0, 'gamma': 0.001} | \n",
" 0.973333 | \n",
" [0.966666666667, 0.966666666667, 0.93333333333... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" parameters mean_validation_score \\\n",
"0 {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0001} 0.973333 \n",
"1 {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0001} 0.920000 \n",
"2 {'kernel': 'linear', 'C': 1.0, 'gamma': 0.0003... 0.973333 \n",
"3 {'kernel': 'rbf', 'C': 1.0, 'gamma': 0.0003162... 0.920000 \n",
"4 {'kernel': 'linear', 'C': 1.0, 'gamma': 0.001} 0.973333 \n",
"\n",
" cv_validation_scores \n",
"0 [0.966666666667, 0.966666666667, 0.93333333333... \n",
"1 [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] \n",
"2 [0.966666666667, 0.966666666667, 0.93333333333... \n",
"3 [0.9, 0.9, 0.9, 0.933333333333, 0.966666666667] \n",
"4 [0.966666666667, 0.966666666667, 0.93333333333... "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(grid_clf.grid_scores_).head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"ベストのモデルは別名が付いている"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"scrolled": true,
"slideshow": {
"slide_type": "-"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'kernel': 'rbf', 'C': 7.75, 'gamma': 0.03162277660168379}\n",
"0.986666666667\n"
]
}
],
"source": [
"print(grid_clf.best_params_)\n",
"print(grid_clf.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# ランダムフォレストの場合\n",
"> iris(4次元)でRF使わないですが..."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"scrolled": true,
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'oob_score': False, 'max_features': 'auto', 'criterion': 'gini', 'n_jobs': 1, 'max_depth': None, 'min_samples_split': 2, 'verbose': 0, 'min_samples_leaf': 1, 'bootstrap': True, 'max_leaf_nodes': None, 'class_weight': None, 'warm_start': False, 'min_weight_fraction_leaf': 0.0, 'random_state': None, 'n_estimators': 10}\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier as RFC\n",
"clf_rf=RFC()\n",
"print(clf_rf.get_params())"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"scrolled": true,
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5, error_score='raise',\n",
" estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
" max_depth=None, max_features='auto', max_leaf_nodes=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n",
" oob_score=False, random_state=None, verbose=0,\n",
" warm_start=False),\n",
" fit_params={}, iid=True, n_jobs=-1,\n",
" param_grid={'n_estimators': [1, 2, 4, 8, 16, 32, 64, 128, 256]},\n",
" pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parameters_rf={'n_estimators':[1*2**i for i in range(9)]}\n",
"grid_clf_rf=grid.GridSearchCV(clf_rf,parameters_rf,n_jobs=-1,cv=5)\n",
"grid_clf_rf.fit(data_shuffle,target_shuffle)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false,
"scrolled": false,
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"
\n",
" \n",
" \n",
" | \n",
" parameters | \n",
" mean_validation_score | \n",
" cv_validation_scores | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" {'n_estimators': 1} | \n",
" 0.920000 | \n",
" [0.833333333333, 0.866666666667, 0.93333333333... | \n",
"
\n",
" \n",
" 1 | \n",
" {'n_estimators': 2} | \n",
" 0.906667 | \n",
" [0.9, 0.833333333333, 0.833333333333, 0.966666... | \n",
"
\n",
" \n",
" 2 | \n",
" {'n_estimators': 4} | \n",
" 0.933333 | \n",
" [0.9, 0.866666666667, 0.966666666667, 0.966666... | \n",
"
\n",
" \n",
" 3 | \n",
" {'n_estimators': 8} | \n",
" 0.953333 | \n",
" [0.933333333333, 0.866666666667, 0.96666666666... | \n",
"
\n",
" \n",
" 4 | \n",
" {'n_estimators': 16} | \n",
" 0.953333 | \n",
" [0.933333333333, 0.866666666667, 0.96666666666... | \n",
"
\n",
" \n",
" 5 | \n",
" {'n_estimators': 32} | \n",
" 0.953333 | \n",
" [0.933333333333, 0.866666666667, 0.96666666666... | \n",
"
\n",
" \n",
" 6 | \n",
" {'n_estimators': 64} | \n",
" 0.953333 | \n",
" [0.933333333333, 0.866666666667, 0.96666666666... | \n",
"
\n",
" \n",
" 7 | \n",
" {'n_estimators': 128} | \n",
" 0.953333 | \n",
" [0.933333333333, 0.866666666667, 0.96666666666... | \n",
"
\n",
" \n",
" 8 | \n",
" {'n_estimators': 256} | \n",
" 0.953333 | \n",
" [0.933333333333, 0.866666666667, 0.96666666666... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" parameters mean_validation_score \\\n",
"0 {'n_estimators': 1} 0.920000 \n",
"1 {'n_estimators': 2} 0.906667 \n",
"2 {'n_estimators': 4} 0.933333 \n",
"3 {'n_estimators': 8} 0.953333 \n",
"4 {'n_estimators': 16} 0.953333 \n",
"5 {'n_estimators': 32} 0.953333 \n",
"6 {'n_estimators': 64} 0.953333 \n",
"7 {'n_estimators': 128} 0.953333 \n",
"8 {'n_estimators': 256} 0.953333 \n",
"\n",
" cv_validation_scores \n",
"0 [0.833333333333, 0.866666666667, 0.93333333333... \n",
"1 [0.9, 0.833333333333, 0.833333333333, 0.966666... \n",
"2 [0.9, 0.866666666667, 0.966666666667, 0.966666... \n",
"3 [0.933333333333, 0.866666666667, 0.96666666666... \n",
"4 [0.933333333333, 0.866666666667, 0.96666666666... \n",
"5 [0.933333333333, 0.866666666667, 0.96666666666... \n",
"6 [0.933333333333, 0.866666666667, 0.96666666666... \n",
"7 [0.933333333333, 0.866666666667, 0.96666666666... \n",
"8 [0.933333333333, 0.866666666667, 0.96666666666... "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(grid_clf_rf.grid_scores_)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# 終わり\n",
"@y__sama"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}