{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "이 노트북의 코드에 대한 설명은 [다중 평가 지표: cross_validate()](https://tensorflow.blog/2018/03/13/%EB%8B%A4%EC%A4%91-%ED%8F%89%EA%B0%80-%EC%A7%80%ED%91%9C-cross_validate/) 글을 참고하세요." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.datasets import load_digits\n", "from sklearn.model_selection import cross_val_score, train_test_split" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Binary target: True only where the digit label equals 9.\n", "digits = load_digits()\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    digits.data,\n", "    digits.target == 9,\n", "    random_state=42,\n", ")" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn.svm import SVC" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.90200445, 0.90200445, 0.90200445])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# No scoring argument: the estimator's default score is used\n", "# (the recorded folds equal the explicit 'accuracy' run in the next cell).\n", "cross_val_score(estimator=SVC(gamma='auto'), X=X_train, y=y_train, cv=3)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.90200445, 0.90200445, 0.90200445])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Same 3-fold run with the metric named explicitly.\n", "cross_val_score(\n", "    estimator=SVC(gamma='auto'),\n", "    X=X_train,\n", "    y=y_train,\n", "    cv=3,\n", "    scoring='accuracy',\n", ")" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_validate" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'fit_time': array([0.03770995, 0.03589416, 0.03686881]),\n", " 'score_time': array([0.12240219, 0.11768389, 0.11690235]),\n", " 'test_accuracy': array([0.90200445, 0.90200445, 0.90200445]),\n", " 'train_accuracy': array([1., 1., 1.]),\n", " 'test_roc_auc': 
array([0.99657688, 0.99814815, 0.99943883]),\n", " 'train_roc_auc': array([1., 1., 1.])}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A list of scorer names produces one test_<name> / train_<name> entry per metric.\n", "cross_validate(\n", "    estimator=SVC(gamma='auto'),\n", "    X=X_train,\n", "    y=y_train,\n", "    scoring=['accuracy', 'roc_auc'],\n", "    cv=3,\n", "    return_train_score=True,\n", ")" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.90200445, 0.90200445, 0.90200445])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Single-metric list; picking 'test_accuracy' out of the result dict\n", "# reproduces the array that cross_val_score returned earlier.\n", "cross_validate(\n", "    estimator=SVC(gamma='auto'),\n", "    X=X_train,\n", "    y=y_train,\n", "    scoring=['accuracy'],\n", "    cv=3,\n", "    return_train_score=False,\n", ")['test_accuracy']" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'fit_time': array([0.03640604, 0.03584003, 0.03449273]),\n", " 'score_time': array([0.11128712, 0.10693693, 0.11939406]),\n", " 'test_acc': array([0.90200445, 0.90200445, 0.90200445]),\n", " 'test_ra': array([0.99657688, 0.99814815, 0.99943883])}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A dict maps custom key suffixes ('acc', 'ra') to scorer names,\n", "# so the result keys become test_acc / test_ra.\n", "cross_validate(\n", "    estimator=SVC(gamma='auto'),\n", "    X=X_train,\n", "    y=y_train,\n", "    scoring={'acc': 'accuracy', 'ra': 'roc_auc'},\n", "    cv=3,\n", "    return_train_score=False,\n", ")" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Candidate gamma values shared by both grid searches below.\n", "param_grid = {\n", "    'gamma': [0.0001, 0.01, 0.1, 1, 10],\n", "}" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=3, estimator=SVC(),\n", " param_grid={'gamma': [0.0001, 0.01, 0.1, 1, 10]}, refit='accuracy',\n", " return_train_score=True, scoring=['accuracy'])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Single-metric search; refit='accuracy' names the metric behind\n", "# best_params_ / best_score_.\n", "grid = GridSearchCV(\n", "    estimator=SVC(),\n", "    param_grid=param_grid,\n", "    scoring=['accuracy'],\n", "    refit='accuracy',\n", "    cv=3,\n", "    return_train_score=True,\n", ")\n", "grid.fit(X_train, y_train)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'gamma': 0.0001}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Parameter setting ranked first on the refit metric.\n", "grid.best_params_" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9651076466221232" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean cross-validated accuracy of that setting.\n", "grid.best_score_" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01234
mean_fit_time0.0067960.0337280.0368650.0291520.028625
std_fit_time0.0000760.0007630.0002550.0003030.000275
mean_score_time0.0127030.0596570.0602910.0545080.054504
std_score_time0.0010030.0005450.0009150.0034510.001073
param_gamma0.00010.010.1110
params{'gamma': 0.0001}{'gamma': 0.01}{'gamma': 0.1}{'gamma': 1}{'gamma': 10}
split0_test_accuracy0.9665920.9020040.9020040.9020040.902004
split1_test_accuracy0.968820.9020040.9020040.9020040.902004
split2_test_accuracy0.9599110.9020040.9020040.9020040.902004
mean_test_accuracy0.9651080.9020040.9020040.9020040.902004
std_test_accuracy0.0037850.00.00.00.0
rank_test_accuracy12222
split0_train_accuracy0.9755011.01.01.01.0
split1_train_accuracy0.9621381.01.01.01.0
split2_train_accuracy0.9743881.01.01.01.0
mean_train_accuracy0.9706761.01.01.01.0
std_train_accuracy0.0060540.00.00.00.0
\n", "
" ], "text/plain": [ " 0 1 2 \\\n", "mean_fit_time 0.006796 0.033728 0.036865 \n", "std_fit_time 0.000076 0.000763 0.000255 \n", "mean_score_time 0.012703 0.059657 0.060291 \n", "std_score_time 0.001003 0.000545 0.000915 \n", "param_gamma 0.0001 0.01 0.1 \n", "params {'gamma': 0.0001} {'gamma': 0.01} {'gamma': 0.1} \n", "split0_test_accuracy 0.966592 0.902004 0.902004 \n", "split1_test_accuracy 0.96882 0.902004 0.902004 \n", "split2_test_accuracy 0.959911 0.902004 0.902004 \n", "mean_test_accuracy 0.965108 0.902004 0.902004 \n", "std_test_accuracy 0.003785 0.0 0.0 \n", "rank_test_accuracy 1 2 2 \n", "split0_train_accuracy 0.975501 1.0 1.0 \n", "split1_train_accuracy 0.962138 1.0 1.0 \n", "split2_train_accuracy 0.974388 1.0 1.0 \n", "mean_train_accuracy 0.970676 1.0 1.0 \n", "std_train_accuracy 0.006054 0.0 0.0 \n", "\n", " 3 4 \n", "mean_fit_time 0.029152 0.028625 \n", "std_fit_time 0.000303 0.000275 \n", "mean_score_time 0.054508 0.054504 \n", "std_score_time 0.003451 0.001073 \n", "param_gamma 1 10 \n", "params {'gamma': 1} {'gamma': 10} \n", "split0_test_accuracy 0.902004 0.902004 \n", "split1_test_accuracy 0.902004 0.902004 \n", "split2_test_accuracy 0.902004 0.902004 \n", "mean_test_accuracy 0.902004 0.902004 \n", "std_test_accuracy 0.0 0.0 \n", "rank_test_accuracy 2 2 \n", "split0_train_accuracy 1.0 1.0 \n", "split1_train_accuracy 1.0 1.0 \n", "split2_train_accuracy 1.0 1.0 \n", "mean_train_accuracy 1.0 1.0 \n", "std_train_accuracy 0.0 0.0 " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ],
"source": [ "# Transpose so each candidate is a column and each cv_results_ key is a row.\n", "# DataFrame.T is pandas' own accessor; no need to route the frame through numpy.\n", "pd.DataFrame(grid.cv_results_).T" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=3, estimator=SVC(),\n", " param_grid={'gamma': [0.0001, 0.01, 0.1, 1, 10]}, refit='ra',\n", " return_train_score=True,\n", " scoring={'acc': 'accuracy', 'ra': 'roc_auc'})" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": 
[ "# Multi-metric search: both scorers are evaluated for every gamma, and\n", "# refit='ra' makes best_params_ / best_score_ follow the roc_auc ranking.\n", "grid = GridSearchCV(\n", "    estimator=SVC(),\n", "    param_grid=param_grid,\n", "    scoring={'acc': 'accuracy', 'ra': 'roc_auc'},\n", "    refit='ra',\n", "    cv=3,\n", "    return_train_score=True,\n", ")\n", "grid.fit(X_train, y_train)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'gamma': 0.01}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Parameter setting ranked first on the 'ra' (roc_auc) metric.\n", "grid.best_params_" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9983352038907594" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Mean cross-validated roc_auc of that setting.\n", "grid.best_score_" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01234
mean_fit_time0.0068640.0342730.0370540.0306540.028933
std_fit_time0.0001720.0002510.0004930.0005550.000616
mean_score_time0.0244930.1146140.1213660.113590.107006
std_score_time0.0025590.0013320.0032620.0033610.00633
param_gamma0.00010.010.1110
params{'gamma': 0.0001}{'gamma': 0.01}{'gamma': 0.1}{'gamma': 1}{'gamma': 10}
split0_test_acc0.9665920.9020040.9020040.9020040.902004
split1_test_acc0.968820.9020040.9020040.9020040.902004
split2_test_acc0.9599110.9020040.9020040.9020040.902004
mean_test_acc0.9651080.9020040.9020040.9020040.902004
std_test_acc0.0037850.00.00.00.0
rank_test_acc12222
split0_train_acc0.9755011.01.01.01.0
split1_train_acc0.9621381.01.01.01.0
split2_train_acc0.9743881.01.01.01.0
mean_train_acc0.9706761.01.01.01.0
std_train_acc0.0060540.00.00.00.0
split0_test_ra0.983670.9974190.9340070.50.5
split1_test_ra0.9871490.9981480.9124580.50.5
split2_test_ra0.9943880.9994390.9104940.50.5
mean_test_ra0.9884030.9983350.9189860.50.5
std_test_ra0.0044650.0008350.0106510.00.0
rank_test_ra21344
split0_train_ra0.9920171.01.01.01.0
split1_train_ra0.9949351.01.01.01.0
split2_train_ra0.989451.01.01.01.0
mean_train_ra0.9921341.01.01.01.0
std_train_ra0.0022410.00.00.00.0
\n", "
" ], "text/plain": [ " 0 1 2 \\\n", "mean_fit_time 0.006864 0.034273 0.037054 \n", "std_fit_time 0.000172 0.000251 0.000493 \n", "mean_score_time 0.024493 0.114614 0.121366 \n", "std_score_time 0.002559 0.001332 0.003262 \n", "param_gamma 0.0001 0.01 0.1 \n", "params {'gamma': 0.0001} {'gamma': 0.01} {'gamma': 0.1} \n", "split0_test_acc 0.966592 0.902004 0.902004 \n", "split1_test_acc 0.96882 0.902004 0.902004 \n", "split2_test_acc 0.959911 0.902004 0.902004 \n", "mean_test_acc 0.965108 0.902004 0.902004 \n", "std_test_acc 0.003785 0.0 0.0 \n", "rank_test_acc 1 2 2 \n", "split0_train_acc 0.975501 1.0 1.0 \n", "split1_train_acc 0.962138 1.0 1.0 \n", "split2_train_acc 0.974388 1.0 1.0 \n", "mean_train_acc 0.970676 1.0 1.0 \n", "std_train_acc 0.006054 0.0 0.0 \n", "split0_test_ra 0.98367 0.997419 0.934007 \n", "split1_test_ra 0.987149 0.998148 0.912458 \n", "split2_test_ra 0.994388 0.999439 0.910494 \n", "mean_test_ra 0.988403 0.998335 0.918986 \n", "std_test_ra 0.004465 0.000835 0.010651 \n", "rank_test_ra 2 1 3 \n", "split0_train_ra 0.992017 1.0 1.0 \n", "split1_train_ra 0.994935 1.0 1.0 \n", "split2_train_ra 0.98945 1.0 1.0 \n", "mean_train_ra 0.992134 1.0 1.0 \n", "std_train_ra 0.002241 0.0 0.0 \n", "\n", " 3 4 \n", "mean_fit_time 0.030654 0.028933 \n", "std_fit_time 0.000555 0.000616 \n", "mean_score_time 0.11359 0.107006 \n", "std_score_time 0.003361 0.00633 \n", "param_gamma 1 10 \n", "params {'gamma': 1} {'gamma': 10} \n", "split0_test_acc 0.902004 0.902004 \n", "split1_test_acc 0.902004 0.902004 \n", "split2_test_acc 0.902004 0.902004 \n", "mean_test_acc 0.902004 0.902004 \n", "std_test_acc 0.0 0.0 \n", "rank_test_acc 2 2 \n", "split0_train_acc 1.0 1.0 \n", "split1_train_acc 1.0 1.0 \n", "split2_train_acc 1.0 1.0 \n", "mean_train_acc 1.0 1.0 \n", "std_train_acc 0.0 0.0 \n", "split0_test_ra 0.5 0.5 \n", "split1_test_ra 0.5 0.5 \n", "split2_test_ra 0.5 0.5 \n", "mean_test_ra 0.5 0.5 \n", "std_test_ra 0.0 0.0 \n", "rank_test_ra 4 4 \n", "split0_train_ra 1.0 1.0 
\n", "split1_train_ra 1.0 1.0 \n", "split2_train_ra 1.0 1.0 \n", "mean_train_ra 1.0 1.0 \n", "std_train_ra 0.0 0.0 " ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ],
"source": [ "# Transpose so each candidate is a column; rows now include both the\n", "# *_acc and *_ra result keys. DataFrame.T replaces the numpy round-trip.\n", "pd.DataFrame(grid.cv_results_).T" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SVC(gamma=0.01)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Estimator carrying the winning parameters (recorded output: SVC(gamma=0.01)).\n", "grid.best_estimator_" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }