{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "rOafwFRXG7fb" }, "source": [ "# 교차 검증과 그리드 서치" ] }, { "cell_type": "markdown", "metadata": { "id": "6YuimmymG7fi" }, "source": [ "\n", " \n", "
\n", " 구글 코랩에서 실행하기\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "dVNF7yZjyvoO" }, "source": [ "## 검증 세트" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "banlvMA6RfnM" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "wine = pd.read_csv('https://bit.ly/wine_csv_data')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "abR6QA7qRoKl" }, "outputs": [], "source": [ "data = wine[['alcohol', 'sugar', 'pH']].to_numpy()\n", "target = wine['class'].to_numpy()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "auLnVXyMRoeb" }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "train_input, test_input, train_target, test_target = train_test_split(\n", " data, target, test_size=0.2, random_state=42)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "E-yV4cCXRqNK" }, "outputs": [], "source": [ "sub_input, val_input, sub_target, val_target = train_test_split(\n", " train_input, train_target, test_size=0.2, random_state=42)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "k29hKbw4R7Ki", "outputId": "17fb7b68-2b4a-4c0d-a026-5b8af9b79ea6" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "(4157, 3) (1040, 3)\n" ] } ], "source": [ "print(sub_input.shape, val_input.shape)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4VQz-UZ2SeLq", "outputId": "c02a1d4d-a2a4-4ec4-f1e3-425bbc976e44" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.9971133028626413\n", "0.864423076923077\n" ] } ], "source": [ "from sklearn.tree import DecisionTreeClassifier\n", "\n", "dt = DecisionTreeClassifier(random_state=42)\n", "dt.fit(sub_input, sub_target)\n", "\n", "print(dt.score(sub_input, sub_target))\n", "print(dt.score(val_input, val_target))" ] }, { "cell_type": "markdown", "metadata": { "id": "Z4gRXnK6y2Pt" }, "source": [ "## 교차 검증" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_J3LId-vSmNH", "outputId": "5f49522d-6e1b-42b3-9c5c-ccd068038dec" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'fit_time': array([0.00931716, 0.00749564, 0.00773239, 0.00731683, 0.00710797]), 'score_time': array([0.00109315, 0.00111032, 0.00101209, 0.00106931, 0.00115085]), 'test_score': array([0.86923077, 0.84615385, 0.87680462, 0.84889317, 0.83541867])}\n" ] } ], "source": [ "from sklearn.model_selection import cross_validate\n", "\n", "scores = cross_validate(dt, train_input, train_target)\n", "print(scores)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Yp3aagOoTHsO", "outputId": "cebb0e68-1b94-4c90-84ab-35156c4fa01c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.855300214703487\n" ] } ], "source": [ "import numpy as np\n", "\n", "print(np.mean(scores['test_score']))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0tQyaG0576Vn", "outputId": "0835bd01-4940-4093-9576-c1d3b8d6e40a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.855300214703487\n" ] } ], "source": [ "from sklearn.model_selection import StratifiedKFold\n", "\n", "scores = cross_validate(dt, train_input, train_target, cv=StratifiedKFold())\n", "print(np.mean(scores['test_score']))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1BmP_OTT_agM", "outputId": "16610b69-fdd4-424d-a18a-17e0851b8171" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.8574181117533719\n" ] } ], "source": [ "splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)\n", "scores = cross_validate(dt, train_input, train_target, cv=splitter)\n", "print(np.mean(scores['test_score']))" ] }, { "cell_type": "markdown", "metadata": { "id": "Q21W8RsqDsDV" }, "source": [ "## 하이퍼파라미터 튜닝" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "S8pqss8onjR5" }, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "params = {'min_impurity_decrease': [0.0001, 0.0002, 0.0003, 0.0004, 0.0005]}" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "79MymJqxTu0P" }, "outputs": [], "source": [ "gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params, n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 117 }, "id": "tKAlTabkU-Lz", "outputId": "4125edd4-08d4-4932-ad11-2108550de929" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "GridSearchCV(estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,\n", " param_grid={'min_impurity_decrease': [0.0001, 0.0002, 0.0003,\n", " 0.0004, 0.0005]})" ], "text/html": [ "
GridSearchCV(estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,\n",
              "             param_grid={'min_impurity_decrease': [0.0001, 0.0002, 0.0003,\n",
              "                                                   0.0004, 0.0005]})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, "metadata": {}, "execution_count": 13 } ], "source": [ "gs.fit(train_input, train_target)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "q6iX3vH-VeEb", "outputId": "b68ad258-4e2a-4e62-f882-9ebea921942e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.9615162593804117\n" ] } ], "source": [ "dt = gs.best_estimator_\n", "print(dt.score(train_input, train_target))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lIzod3BwVHq-", "outputId": "fc1ea12a-2a33-4ebe-8438-2c9991e6da1a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'min_impurity_decrease': 0.0001}\n" ] } ], "source": [ "print(gs.best_params_)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0xfQswiui4Tr", "outputId": "7a74abdb-a237-49a0-d8cd-fe97fa9808bd" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[0.86819297 0.86453617 0.86492226 0.86780891 0.86761605]\n" ] } ], "source": [ "print(gs.cv_results_['mean_test_score'])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Rwg2aSyEVO17", "outputId": "b358e38f-22e4-4a1b-e688-866cd78398fc" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'min_impurity_decrease': 0.0001}\n" ] } ], "source": [ "best_index = np.argmax(gs.cv_results_['mean_test_score'])\n", "print(gs.cv_results_['params'][best_index])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "8jHxZ7XmVU11" }, "outputs": [], "source": [ "params = {'min_impurity_decrease': np.arange(0.0001, 0.001, 0.0001),\n", " 'max_depth': range(5, 20, 1),\n", " 'min_samples_split': range(2, 100, 10)\n", " }" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 117 }, "id": "KnP3GA6MVsVH", "outputId": "7540ba6a-bda4-4c9e-beca-484dd8ef540f" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "GridSearchCV(estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,\n", " param_grid={'max_depth': range(5, 20),\n", " 'min_impurity_decrease': array([0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.0008,\n", " 0.0009]),\n", " 'min_samples_split': range(2, 100, 10)})" ], "text/html": [ "
GridSearchCV(estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,\n",
              "             param_grid={'max_depth': range(5, 20),\n",
              "                         'min_impurity_decrease': array([0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.0008,\n",
              "       0.0009]),\n",
              "                         'min_samples_split': range(2, 100, 10)})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, "metadata": {}, "execution_count": 19 } ], "source": [ "gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params, n_jobs=-1)\n", "gs.fit(train_input, train_target)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qi9-O_VGV0Ho", "outputId": "c3eae726-486f-4a1d-f4b4-7074715fd6b4" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'max_depth': 14, 'min_impurity_decrease': 0.0004, 'min_samples_split': 12}\n" ] } ], "source": [ "print(gs.best_params_)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZnJjLATAV2Sq", "outputId": "81351bd9-ec0b-416f-963a-f07fbbf6a0d9" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.8683865773302731\n" ] } ], "source": [ "print(np.max(gs.cv_results_['mean_test_score']))" ] }, { "cell_type": "markdown", "metadata": { "id": "d0k9DQTNlaD6" }, "source": [ "### 랜덤 서치" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "_T9KTEk1GBcY" }, "outputs": [], "source": [ "from scipy.stats import uniform, randint" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fd0UJpCGGDhz", "outputId": "d5472c8c-2bfc-4ee0-aae5-e2b613f1f4b0" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([4, 7, 6, 8, 9, 3, 8, 3, 1, 4])" ] }, "metadata": {}, "execution_count": 23 } ], "source": [ "rgen = randint(0, 10)\n", "rgen.rvs(10)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ch3zTUohIJR6", "outputId": "95ae5d40-2edb-43b5-c01f-40d16eb517b1" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),\n", " array([116, 105, 95, 100, 84, 90, 97, 95, 107, 111]))" ] }, "metadata": {}, "execution_count": 24 } ], "source": [ "np.unique(rgen.rvs(1000), return_counts=True)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bGhshTn0IjkI", "outputId": "bd9ff812-9b11-4b1e-dd9a-9f9a7010469d" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0.07156624, 0.51330724, 0.78244744, 0.14237963, 0.05055468,\n", " 0.13124955, 0.15801332, 0.99110938, 0.08459786, 0.92447632])" ] }, "metadata": {}, "execution_count": 25 } ], "source": [ "ugen = uniform(0, 1)\n", "ugen.rvs(10)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "irDX9e6WYTIH" }, "outputs": [], "source": [ "params = {'min_impurity_decrease': uniform(0.0001, 0.001),\n", " 'max_depth': randint(20, 50),\n", " 'min_samples_split': randint(2, 25),\n", " 'min_samples_leaf': randint(1, 25),\n", " }" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 117 }, "id": "Wc4OIingWQCK", "outputId": "f1783475-8d2e-4cb3-845f-643fcc517ed5" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "RandomizedSearchCV(estimator=DecisionTreeClassifier(random_state=42),\n", " n_iter=100, n_jobs=-1,\n", " param_distributions={'max_depth': ,\n", " 'min_impurity_decrease': ,\n", " 'min_samples_leaf': ,\n", " 'min_samples_split': },\n", " random_state=42)" ], "text/html": [ "
RandomizedSearchCV(estimator=DecisionTreeClassifier(random_state=42),\n",
              "                   n_iter=100, n_jobs=-1,\n",
              "                   param_distributions={'max_depth': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7cccce351cc0>,\n",
              "                                        'min_impurity_decrease': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x7cccce2f4610>,\n",
              "                                        'min_samples_leaf': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7cccce352da0>,\n",
              "                                        'min_samples_split': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7cccce353bb0>},\n",
              "                   random_state=42)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, "metadata": {}, "execution_count": 27 } ], "source": [ "from sklearn.model_selection import RandomizedSearchCV\n", "\n", "gs = RandomizedSearchCV(DecisionTreeClassifier(random_state=42), params,\n", " n_iter=100, n_jobs=-1, random_state=42)\n", "gs.fit(train_input, train_target)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p7IbsGH3ZSv-", "outputId": "b2f963f8-7f7a-4af0-98aa-20eac00d265a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'max_depth': 39, 'min_impurity_decrease': 0.00034102546602601173, 'min_samples_leaf': 7, 'min_samples_split': 13}\n" ] } ], "source": [ "print(gs.best_params_)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dYI3HwMQbtnr", "outputId": "2cb232e6-28d8-42b0-db66-bfa0bc58832d" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.8695428296438884\n" ] } ], "source": [ "print(np.max(gs.cv_results_['mean_test_score']))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3QV7yRpidByf", "outputId": "f449fcd7-ad05-45b8-a86e-7f03bc565937" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.86\n" ] } ], "source": [ "dt = gs.best_estimator_\n", "\n", "print(dt.score(test_input, test_target))" ] }, { "cell_type": "markdown", "metadata": { "id": "cA42IsMdhgE7" }, "source": [ "## 확인문제" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 117 }, "id": "8qxg36iThiUm", "outputId": "f2c84c8b-a07f-4b61-a708-a869c685b5ed" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "RandomizedSearchCV(estimator=DecisionTreeClassifier(random_state=42,\n", " splitter='random'),\n", " n_iter=100, n_jobs=-1,\n", " param_distributions={'max_depth': ,\n", " 'min_impurity_decrease': ,\n", " 'min_samples_leaf': ,\n", " 'min_samples_split': },\n", " random_state=42)" ], "text/html": [ "
RandomizedSearchCV(estimator=DecisionTreeClassifier(random_state=42,\n",
              "                                                    splitter='random'),\n",
              "                   n_iter=100, n_jobs=-1,\n",
              "                   param_distributions={'max_depth': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7cccce351cc0>,\n",
              "                                        'min_impurity_decrease': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x7cccce2f4610>,\n",
              "                                        'min_samples_leaf': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7cccce352da0>,\n",
              "                                        'min_samples_split': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7cccce353bb0>},\n",
              "                   random_state=42)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, "metadata": {}, "execution_count": 31 } ], "source": [ "gs = RandomizedSearchCV(DecisionTreeClassifier(splitter='random', random_state=42), params,\n", " n_iter=100, n_jobs=-1, random_state=42)\n", "gs.fit(train_input, train_target)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CMZ4UE8ihqwg", "outputId": "5e7389ae-bf93-404e-eb2f-04f50af7da60" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'max_depth': 43, 'min_impurity_decrease': 0.00011407982271508446, 'min_samples_leaf': 19, 'min_samples_split': 18}\n", "0.8458726956392981\n", "0.786923076923077\n" ] } ], "source": [ "print(gs.best_params_)\n", "print(np.max(gs.cv_results_['mean_test_score']))\n", "\n", "dt = gs.best_estimator_\n", "print(dt.score(test_input, test_target))" ] } ], "metadata": { "colab": { "name": "5-2 교차 검증과 그리드 서치.ipynb", "provenance": [] }, "kernelspec": { "display_name": "default:Python", "language": "python", "name": "conda-env-default-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" } }, "nbformat": 4, "nbformat_minor": 0 }