{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The watermark extension is already loaded. To reload it, use:\n",
      "  %reload_ext watermark\n",
      "CPython 3.5.4\n",
      "IPython 6.2.1\n",
      "\n",
      "numpy 1.14.0\n",
      "scipy 1.0.0\n",
      "sklearn 0.20.0\n",
      "pandas 0.22.0\n",
      "matplotlib 2.1.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/envs/mlbook/lib/python3.5/site-packages/sklearn/utils/deprecation.py:77: DeprecationWarning: Function fetch_mldata is deprecated; fetch_mldata was deprecated in version 0.20 and will be removed in version 0.22\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n",
      "/anaconda3/envs/mlbook/lib/python3.5/site-packages/sklearn/utils/deprecation.py:77: DeprecationWarning: Function mldata_filename is deprecated; mldata_filename was deprecated in version 0.20 and will be removed in version 0.22\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -v -p numpy,scipy,sklearn,pandas,matplotlib\n",
    "# Python 2 and Python 3 support\n",
    "from __future__ import division, print_function, unicode_literals\n",
    "\n",
    "# Common imports\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# Seed the pseudo-random generator so the output is reproducible\n",
    "np.random.seed(42)\n",
    "\n",
    "# Matplotlib settings\n",
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['axes.labelsize'] = 14\n",
    "plt.rcParams['xtick.labelsize'] = 12\n",
    "plt.rcParams['ytick.labelsize'] = 12\n",
    "\n",
    "# Korean text in plots; draw minus signs with a plain hyphen\n",
    "# (presumably because the Korean font lacks the unicode minus glyph)\n",
    "matplotlib.rc('font', family='NanumBarunGothic')\n",
    "plt.rcParams['axes.unicode_minus'] = False\n",
    "\n",
    "# Folder where figures are saved\n",
    "PROJECT_ROOT_DIR = \".\"\n",
    "CHAPTER_ID = \"classification\"\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True):\n",
    "    \"\"\"Save the current figure as PROJECT_ROOT_DIR/images/CHAPTER_ID/fig_id.png at 300 dpi.\"\"\"\n",
    "    path = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id + \".png\")\n",
    "    # plt.savefig does not create missing folders, so create it on first use.\n",
    "    # (os.makedirs(..., exist_ok=True) is avoided to keep Python 2 support.)\n",
    "    fig_dir = os.path.dirname(path)\n",
    "    if not os.path.isdir(fig_dir):\n",
    "        os.makedirs(fig_dir)\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format='png', dpi=300)\n",
    "\n",
    "# --- Rebuild the state from sections 3.1-3.3 that the cells below rely on ---\n",
    "# NOTE: fetch_mldata is deprecated (removed in scikit-learn 0.22) and mldata.org is offline;\n",
    "# newer versions need fetch_openml('mnist_784', version=1), whose sample order differs.\n",
    "from sklearn.datasets import fetch_mldata\n",
    "mnist = fetch_mldata('MNIST original')\n",
    "X, y = mnist[\"data\"], mnist[\"target\"]\n",
    "\n",
    "some_digit = X[36000]\n",
    "some_digit_image = some_digit.reshape(28, 28)\n",
    "\n",
    "def plot_digit(data):\n",
    "    \"\"\"Draw one flattened 28x28 image in grayscale, without axes.\"\"\"\n",
    "    image = data.reshape(28, 28)\n",
    "    plt.imshow(image, cmap = matplotlib.cm.binary,\n",
    "               interpolation=\"nearest\")\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "# Extra helper for drawing several digits at once\n",
    "def plot_digits(instances, images_per_row=10, **options):\n",
    "    \"\"\"Tile flattened 28x28 images into a grid of images_per_row columns and draw it.\"\"\"\n",
    "    size = 28\n",
    "    images_per_row = min(len(instances), images_per_row)\n",
    "    images = [instance.reshape(size,size) for instance in instances]\n",
    "    n_rows = (len(instances) - 1) // images_per_row + 1\n",
    "    row_images = []\n",
    "    # Pad the last row with a blank strip so every row has the same width\n",
    "    n_empty = n_rows * images_per_row - len(instances)\n",
    "    images.append(np.zeros((size, size * n_empty)))\n",
    "    for row in range(n_rows):\n",
    "        rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
    "        row_images.append(np.concatenate(rimages, axis=1))\n",
    "    image = np.concatenate(row_images, axis=0)\n",
    "    plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
    "\n",
    "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]\n",
    "\n",
    "shuffle_index = np.random.permutation(60000)\n",
    "X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]\n",
    "y_train_5 = (y_train == 5)\n",
    "y_test_5 = (y_test == 5)\n",
    "\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "\n",
    "sgd_clf = SGDClassifier(max_iter=5, random_state=42)\n",
    "sgd_clf.fit(X_train, y_train_5)\n",
    "sgd_clf.predict([some_digit])\n",
    "from sklearn.model_selection import cross_val_score\n",
    "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.base import clone\n",
    "\n",
    "# random_state only has an effect when shuffle=True (newer scikit-learn raises otherwise)\n",
    "skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)\n",
    "\n",
    "for train_index, test_index in skfolds.split(X_train, y_train_5):\n",
    "    clone_clf = clone(sgd_clf)\n",
    "    X_train_folds = X_train[train_index]\n",
    "    y_train_folds = y_train_5[train_index]\n",
    "    X_test_fold = X_train[test_index]\n",
    "    y_test_fold = y_train_5[test_index]\n",
    "\n",
    "    clone_clf.fit(X_train_folds, y_train_folds)\n",
    "    y_pred = clone_clf.predict(X_test_fold)\n",
    "    n_correct = sum(y_pred == y_test_fold)\n",
    "    # print(n_correct / len(y_pred))\n",
    "from sklearn.base import BaseEstimator\n",
    "class Never5Classifier(BaseEstimator):\n",
    "    \"\"\"Dummy baseline that predicts 'not 5' (False) for every sample.\"\"\"\n",
    "    def fit(self, X, y=None):\n",
    "        # Nothing to learn; scikit-learn estimators return self from fit()\n",
    "        return self\n",
    "    def predict(self, X):\n",
    "        return np.zeros((len(X), 1), dtype=bool)\n",
    "never_5_clf = Never5Classifier()\n",
    "cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "\n",
    "y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "confusion_matrix(y_train_5, y_train_pred)\n",
    "y_train_perfect_predictions = y_train_5\n",
    "\n",
    "confusion_matrix(y_train_5, y_train_perfect_predictions)\n",
    "from sklearn.metrics import precision_score, recall_score\n",
    "\n",
    "precision_score(y_train_5, y_train_pred)\n",
    "from sklearn.metrics import f1_score\n",
    "f1_score(y_train_5, y_train_pred)\n",
    "y_scores = sgd_clf.decision_function([some_digit])\n",
    "threshold = 0\n",
    "y_some_digit_pred = (y_scores > threshold)\n",
    "threshold = 200000\n",
    "y_some_digit_pred = (y_scores > threshold)\n",
    "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
    "                             method=\"decision_function\")\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "\n",
    "precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)\n",
    "def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n",
    "    plt.plot(thresholds, precisions[:-1], \"b--\", label=\"정밀도\", linewidth=2)\n",
    "    plt.plot(thresholds, recalls[:-1], \"g-\", label=\"재현율\", linewidth=2)\n",
    "    plt.xlabel(\"임계값\", fontsize=16)\n",
    "    plt.legend(loc=\"upper left\", fontsize=16)\n",
    "    plt.ylim([0, 1])\n",
    "\n",
    "y_train_pred_90 = (y_scores > 70000)\n",
    "\n",
    "def plot_precision_vs_recall(precisions, recalls):\n",
    "    plt.plot(recalls, precisions, \"b-\", linewidth=2)\n",
    "    plt.xlabel(\"재현율\", fontsize=16)\n",
    "    plt.ylabel(\"정밀도\", fontsize=16)\n",
    "    plt.axis([0, 1, 0, 1])\n",
    "from sklearn.metrics import roc_curve\n",
    "\n",
    "fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)\n",
    "def plot_roc_curve(fpr, tpr, label=None):\n",
    "    plt.plot(fpr, tpr, linewidth=2, label=label)\n",
    "    plt.plot([0, 1], [0, 1], 'k--')\n",
    "    plt.axis([0, 1, 0, 1])\n",
    "    plt.xlabel('거짓 양성 비율', fontsize=16)\n",
    "    plt.ylabel('진짜 양성 비율', fontsize=16)\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "roc_auc_score(y_train_5, y_scores)\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "forest_clf = RandomForestClassifier(n_estimators=10, random_state=42)\n",
    "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
    "                                    method=\"predict_proba\")\n",
    "y_scores_forest = y_probas_forest[:, 1]  # score = probability of the positive class\n",
    "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)\n",
    "y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 3. 
분류\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.4 다중 분류" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 이진 분류기 : 두 개의 클래스를 구별
\n", "- **다중 분류기**multiclass classifier(또는 **다항 분류기**multinomial classifier) : 둘 이상의 클래스를 구별
\n", " - **일대다**one-versus-all, one-versus-the-rest(OvA) 전략 : 이진 분류기를 여러 개 사용해 다중 클래스를 분류하는 기법. 이는 이미지를 분류할 때 각 분류기의 결정 점수 중에서 가장 높은 것을 클래스로 선택
\n", " - **일대일**one-versus-one(OvO) 전략 : 0과 1 구별, 0과 2 구별, 1과 2 구별 등과 같이 각 숫자의 조합마다 이진 분류기를 훈련. 클래스가 N개라면 분류기는 N*(N-1)/2개가 필요" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다중 클래스 분류 작업에 이진 분류 알고리즘을 선택하면 사이킷런이 자동으로 OvA(SVM 분류기일 때는 OvO)를 적용
\n", "SGDClassifier를 적용해보겠습니다." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([5.])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sgd_clf.fit(X_train, y_train)\n", "sgd_clf.predict([some_digit])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 코드는 5를 구별한 타깃 클래스(y_train_5) 대신 0에서 9까지의 원래 타깃 클래스(y_train)를 사용
\n", "내부에서는 사이킷런이 실제로 10개의 이진 분류기를 훈련시키고 각각의 결정 점수를 얻어 점수가 가장 높은 클래스를 선택
\n", "이를 확인하기 위해 decision_function() 메서드를 호출" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-311402.62954431, -363517.28355739, -446449.5306454 ,\n", " -183226.61023518, -414337.15339485, 161855.74572176,\n", " -452576.39616343, -471957.14962573, -518542.33997148,\n", " -536774.63961222]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "some_digit_scores = sgd_clf.decision_function([some_digit])\n", "some_digit_scores" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(some_digit_scores)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sgd_clf.classes_" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5.0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sgd_clf.classes_[np.argmax(some_digit_scores)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 사이킷런에서 OvO나 OvA을 사용하도록 강제하려면 OneVsOneClassifier나 OneVsRestClassifier를 사용" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([5.])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.multiclass import OneVsOneClassifier\n", "ovo_clf = OneVsOneClassifier(SGDClassifier(max_iter=5, random_state=42))\n", "ovo_clf.fit(X_train, y_train)\n", "ovo_clf.predict([some_digit])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "45" ] }, "execution_count": 11, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "len(ovo_clf.estimators_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 아래 코드는 RandomForestClassifier를 훈련" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([5.])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forest_clf.fit(X_train, y_train)\n", "forest_clf.predict([some_digit])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 랜덤 포레스트 분류기는 직접 샘플을 다중 클래스로 분류할 수 있기 때문에 OvA나 OvO를 적용할 필요가 없음\n", "- predict_proba() 메서드를 호출하면 분류기가 각 샘플에 부여한 클래스별 확률을 얻을 수 있음" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.1, 0. , 0. , 0.1, 0. , 0.8, 0. , 0. , 0. , 0. ]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forest_clf.predict_proba([some_digit])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 이제 교차 검증을 사용하여 분류기를 평가함
\n", "cross_val_score() 함수를 사용해 SGDClassifier의 정확도를 평가 " ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.84063187, 0.84899245, 0.86652998])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "분류기의 성능을 더 높이기 위해 입력의 스케일을 조정함" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.91011798, 0.90874544, 0.906636 ])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.preprocessing import StandardScaler\n", "scaler = StandardScaler()\n", "X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))\n", "cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring=\"accuracy\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.5 에러 분석" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 모델의 성능을 향상시킬 한 가지 방법은 만들어진 에러의 종류를 분석하는 것
\n", "먼저 오차 행렬을 살펴보기위해 cross_val_predict() 함수를 사용해 예측을 만들고 confusion_matrix() 함수를 호출함" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5725, 3, 24, 9, 10, 49, 50, 10, 39, 4],\n", " [ 2, 6493, 43, 25, 7, 40, 5, 10, 109, 8],\n", " [ 51, 41, 5321, 104, 89, 26, 87, 60, 166, 13],\n", " [ 47, 46, 141, 5342, 1, 231, 40, 50, 141, 92],\n", " [ 19, 29, 41, 10, 5366, 9, 56, 37, 86, 189],\n", " [ 73, 45, 36, 193, 64, 4582, 111, 30, 193, 94],\n", " [ 29, 34, 44, 2, 42, 85, 5627, 10, 45, 0],\n", " [ 25, 24, 74, 32, 54, 12, 6, 5787, 15, 236],\n", " [ 52, 161, 73, 156, 10, 163, 61, 25, 5027, 123],\n", " [ 43, 35, 26, 92, 178, 28, 2, 223, 82, 5240]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)\n", "conf_mx = confusion_matrix(y_train, y_train_pred)\n", "conf_mx #행은 실제 클래스, 열은 예측한 클래스" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "오차 행렬을 맷플롯립의 matshow() 함수를 사용해 이미지로 표현" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda3/envs/mlbook/lib/python3.5/site-packages/matplotlib/font_manager.py:1320: UserWarning: findfont: Font family ['NanumBarunGothic'] not found. 
Falling back to DejaVu Sans\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEFCAYAAAAsdjEBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAADA5JREFUeJzt3V+InXedx/H3J5nGxrrVhtZQTEikVFBpTbe5W/rnorK2sOzWChs2LLI3kRYFRS8t1BQveuMipnYJhCIqu/TCtUsUvfBCWIps0m1LsJVQt9akGExN1f4bk8l89+JMIcR0zjM6v3nm+Hu/YAgzffLjm9PznufMnOf8TqoKSf3ZMPYAksZh/FKnjF/qlPFLnTJ+qVPGL3XK+KVOjRp/ki1J/jPJ60leTPJPY84zTZJ3JDm0NOurSZ5KcufYcw2R5Pok80m+NfYsQyTZk+S5pfvGz5PcMvZMy0myM8n3k7yS5FSSA0nmxp5rOWOf+R8GzgJbgb3AI0k+PO5Iy5oDTgC3Ae8G7gceS7JzxJmGehg4MvYQQyT5KPAQ8C/AXwG3Av836lDTfR34NXAtsIvJfeS+USeaYrT4k1wB3APcX1WvVdV/A/8F/PNYM01TVa9X1QNV9YuqWqyqw8ALwM1jz7acJHuA3wI/GnuWgb4E7K+qnyzdzi9V1UtjDzXF+4HHqmq+qk4BPwDW84ls1DP/B4DzVXX8gq89wzq/wS6UZCuTf8dPx57l7SS5EtgPfH7sWYZIshHYDVyT5PkkJ5ceQm8ee7YpvgrsSfLOJO8D7mTyDWDdGjP+dwG/u+hrv2PyMG/dS3IZ8G3gG1X1s7HnWcaDwKGqOjH2IANtBS4DPgHcwuQh9E3AF8ccaoAfMzlx/R44CRwFvjvqRFOMGf9rwJUXfe1K4NURZlmRJBuAbzL5fcWnRx7nbSXZBdwB/OvYs6zAm0t/fq2qflVVLwNfAe4acaZlLd0ffgh8B7gCuBq4isnvLdatMeM/Dswluf6Cr32EdfwQGiBJgENMzlD3VNW5kUdazu3ATuCXSU4BXwDuSfK/Yw61nKp6hcmZc5ZebroF2A4cqKo/VNVvgEdZx9+wYMT4q+p1Jt8p9ye5IsnfAH/P5Iy6nj0CfBD4u6p6c9rBIzsIXMfkofMu4N+A7wF/O+ZQAzwKfCbJe5NcBXwWODzyTG9r6dHJC8C9SeaSvAf4JJPfYa1bYz/Vdx+wmclTJP8O3FtV6/bMn2QH8CkmIZ1K8trSx96RR7ukqnqjqk699cHkR635qjo99mxTPMjkacnjwHPAU8CXR51ouo8DHwNOA88DC8DnRp1oiriZh9Snsc/8kkZi/FKnjF/qlPFLnTJ+qVPGL3VqXcSfZN/YM6zUrM08a/OCM7e2LuIHZuYGu8CszTxr84IzN7Ve4pe0xppd4bdly5batm3boGPPnDnDli1bBh177NixP2csqQtVlWnHNNtjbNu2bRw+vPqvxdixY8eqr6k/Nnnx4mxpdSJreVuMeXm9D/ulThm/1Cnjlzpl/FKnjF/q1KD4Z+2ddSRNN/SpvgvfWWcX8L0kz6znLbckLW/qmX8W31lH0nRDHvbP/DvrSPpjQ+If/M46SfYlOZrk6JkzZ1ZjPkmNDIl/8DvrVNXBqtpdVbuHXqsvaRxD4p/Jd9aRtLyp8c/wO+tIWsbQi3xm6p11JE036Hn+qjoD/EPjWSStIS/vlTpl/FKnjF/qlPFLnWq2gWeSJgu33PNsw4Y23wtn8W3QW+1bN4u3xdxcs60uWVhYaLLukA08PfNLnTJ+qVPGL3XK+KVOGb/UKeOXOmX8UqeMX+qU8UudMn6pU8Yvdcr4pU4Zv9Qp45c6ZfxSp4xf6p
TxS50yfqlTxi91yvilThm/1CnjlzrVbk9i2myF3Wp7bYCnn366ybo333xzk3Wh3VbYi4uLTdbduHFjk3Wh3W3R8j43pr/Mf5WkqYxf6pTxS50yfqlTxi91yvilThm/1Kmp8Sd5R5JDSV5M8mqSp5LcuRbDSWpnyJl/DjgB3Aa8G7gfeCzJznZjSWpt6hV+VfU68MAFXzqc5AXgZuAXbcaS1NqKf+ZPshX4APDT1R9H0lpZ0bX9SS4Dvg18o6p+don/vg/Yt0qzSWpocPxJNgDfBM4Cn77UMVV1EDi4dHybV1lIWhWD4k8S4BCwFbirqs41nUpSc0PP/I8AHwTuqKo3G84jaY0MeZ5/B/ApYBdwKslrSx97m08nqZkhT/W9CGQNZpG0hry8V+qU8UudMn6pU8YvdSqtdjxNUpPLA2bH3FybzYyffPLJJusC3HjjjU3W3bx5c5N15+fnm6zbUqv7BbTZJfn8+fNU1dT4PPNLnTJ+qVPGL3XK+KVOGb/UKeOXOmX8UqeMX+qU8UudMn6pU8Yvdcr4pU4Zv9Qp45c6ZfxSp4xf6pTxS50yfqlTxi91yvilThm/1CnjlzrVdOvuRuu2WLapVrcxwLFjx5qse8MNNzRZd8OGduebVrdzy5lbbAt+9uxZFhcX3bpb0qUZv9Qp45c6ZfxSp4xf6pTxS50yfqlTK4o/yfVJ5pN8q9VAktbGSs/8DwNHWgwiaW0Njj/JHuC3wI/ajSNprQyKP8mVwH7g823HkbRWhl5Y/CBwqKpOLHdtfZJ9wL7VGExSW1PjT7ILuAO4adqxVXUQOLj099q9mkXSn23Imf92YCfwy6Wz/ruAjUk+VFV/3W40SS0Nif8g8B8XfP4FJt8M7m0xkKS1MTX+qnoDeOOtz5O8BsxX1emWg0lqa8U7CVTVAw3mkLTGvLxX6pTxS50yfqlTxi91qunuvS12PW25E24rmzZtarb2uXPnmqz7+OOPN1n37rvvbrIuwPnz55us2/L/38LCwqqvef78earK3XslXZrxS50yfqlTxi91yvilThm/1Cnjlzpl/FKnjF/qlPFLnTJ+qVPGL3XK+KVOGb/UKeOXOmX8UqeMX+qU8UudMn6pU8Yvdcr4pU413b136V19u9dyx+FWt3GLnZcBjh8/3mRdgOuuu67Jui3vx63uG+7eK+ltGb/UKeOXOmX8UqeMX+qU8UudMn6pU4PjT7InyXNJXk/y8yS3tBxMUltzQw5K8lHgIeAfgf8Brm05lKT2BsUPfAnYX1U/Wfr8pUbzSFojUx/2J9kI7AauSfJ8kpNJDiTZ3H48Sa0M+Zl/K3AZ8AngFmAXcBPwxYsPTLIvydEkR1d1Skmrbkj8by79+bWq+lVVvQx8Bbjr4gOr6mBV7a6q3as5pKTVNzX+qnoFOAm0e2mapDU39Km+R4HPJHlvkquAzwKH240lqbWhv+1/ELgaOA7MA48BX241lKT2BsVfVeeA+5Y+JP0F8PJeqVPGL3XK+KVOGb/UKeOXOtV06+4mCzfUarvqWdy6e3Fxscm6LZ04caLJutu3b2+yLsDmzav/Epn5+XkWFxfdulvSpRm/1Cnjlzpl/FKnjF/qlPFLnTJ+qVPGL3XK+KVOGb/UKeOXOmX8UqeMX+qU8UudMn6pU8Yvdcr4pU4Zv9Qp45c6ZfxSp4xf6lTT3Xtb7IY7Nzf0vUVXbmFhocm6LWc+e/Zsk3U3bdrUZN1WtzG023H4iSeeaLIuwK233rrqay4sLFBV7t4r6dKMX+qU8UudMn6pU8Yvdcr4pU4Zv9SpQfEn2Znk+0leSXIqyYEk7Z68ltTc0DP/14FfA9cCu4DbgPtaDSWpvaHxvx94rKrmq+oU8APgw+3GktTa0Pi/CuxJ8s4k7wPuZPINQNKMGhr/j5mc6X8PnASOAt+9+KAk+5IcTXJ09UaU1MLU+JNsAH4IfAe4ArgauAp46OJjq+pgVe2uqt2rPaik1TXkzL8F2A4cqKo/VNVvgE
eBu5pOJqmpqfFX1cvAC8C9SeaSvAf4JPBM6+EktTP0Z/6PAx8DTgPPAwvA51oNJam9QRfqVNXTwO1tR5G0lry8V+qU8UudMn6pU8Yvdcr4pU413bo7mbp7cBdabGH+llbbgrfaErzV/Q3g8ssvb7LuuXPnmqwLcOTIkVVfc+/evTz77LNu3S3p0oxf6pTxS50yfqlTxi91yvilThm/1Cnjlzpl/FKnjF/qlPFLnTJ+qVPGL3XK+KVOGb/UKeOXOmX8UqeMX+qU8UudMn6pU8Yvdarl7r2ngRcHHn418HKTQdqZtZlnbV5w5j/Vjqq6ZtpBzeJfiSRHq2r32HOsxKzNPGvzgjO35sN+qVPGL3VqvcR/cOwB/gSzNvOszQvO3NS6+Jlf0tpbL2d+SWvM+KVOGb/UKeOXOmX8Uqf+H+BU3ATt8qFYAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.matshow(conf_mx, cmap=plt.cm.gray)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 오차 행렬은 대부분의 이미지가 올바르게 분류되었음을 나타내는 주대각선에 있으므로 매우 좋음" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "그래프의 에러 부분에 초점을 맞추기 위해 오차 행렬의 각 값을 대응되는 클래스의 이미지 개수로 나누어 에러 비율을 비교(MNIST는 클래스별 이미지 개수가 동일하지 않음)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5923],\n", " [6742],\n", " [5958],\n", " [6131],\n", " [5842],\n", " [5421],\n", " [5918],\n", " [6265],\n", " [5851],\n", " [5949]])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "row_sums = conf_mx.sum(axis=1, keepdims=True) # column합\n", "row_sums" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[9.66570994e-01, 5.06500084e-04, 4.05200068e-03, 1.51950025e-03,\n", " 1.68833361e-03, 8.27283471e-03, 8.44166807e-03, 1.68833361e-03,\n", " 6.58450110e-03, 6.75333446e-04],\n", " [2.96647879e-04, 9.63067339e-01, 6.37792940e-03, 3.70809849e-03,\n", " 1.03826758e-03, 5.93295758e-03, 7.41619697e-04, 1.48323939e-03,\n", " 1.61673094e-02, 1.18659152e-03],\n", " [8.55991944e-03, 6.88150386e-03, 8.93084928e-01, 1.74555220e-02,\n", " 1.49378986e-02, 4.36388050e-03, 1.46022155e-02, 1.00704935e-02,\n", " 2.78616986e-02, 2.18194025e-03],\n", " [7.66595988e-03, 7.50285435e-03, 2.29978796e-02, 
8.71309737e-01,\n", " 1.63105529e-04, 3.76773773e-02, 6.52422117e-03, 8.15527646e-03,\n", " 2.29978796e-02, 1.50057087e-02],\n", " [3.25231085e-03, 4.96405341e-03, 7.01814447e-03, 1.71174255e-03,\n", " 9.18521054e-01, 1.54056830e-03, 9.58575830e-03, 6.33344745e-03,\n", " 1.47209860e-02, 3.23519343e-02],\n", " [1.34661502e-02, 8.30105147e-03, 6.64084117e-03, 3.56022874e-02,\n", " 1.18059399e-02, 8.45231507e-01, 2.04759270e-02, 5.53403431e-03,\n", " 3.56022874e-02, 1.73399742e-02],\n", " [4.90030416e-03, 5.74518418e-03, 7.43494424e-03, 3.37952011e-04,\n", " 7.09699223e-03, 1.43629605e-02, 9.50827982e-01, 1.68976005e-03,\n", " 7.60392024e-03, 0.00000000e+00],\n", " [3.99042298e-03, 3.83080607e-03, 1.18116520e-02, 5.10774142e-03,\n", " 8.61931365e-03, 1.91540303e-03, 9.57701516e-04, 9.23703113e-01,\n", " 2.39425379e-03, 3.76695930e-02],\n", " [8.88736968e-03, 2.75166638e-02, 1.24764997e-02, 2.66621090e-02,\n", " 1.70910955e-03, 2.78584857e-02, 1.04255683e-02, 4.27277388e-03,\n", " 8.59169373e-01, 2.10220475e-02],\n", " [7.22810556e-03, 5.88334174e-03, 4.37048243e-03, 1.54647840e-02,\n", " 2.99209951e-02, 4.70667339e-03, 3.36190956e-04, 3.74852916e-02,\n", " 1.37838292e-02, 8.80820306e-01]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "norm_conf_mx = conf_mx / row_sums\n", "norm_conf_mx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다른 항목은 그대로 유지하고 주대각선만 0으로 채워서 그래프를 그림" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda3/envs/mlbook/lib/python3.5/site-packages/matplotlib/font_manager.py:1320: UserWarning: findfont: Font family ['NanumBarunGothic'] not found. 
Falling back to DejaVu Sans\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEFCAYAAAAsdjEBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAADUZJREFUeJzt3V+InfWZwPHvk2QSTLpRgxq0FBMWi26Jxu2AyOKqoGwNLKtVVFbWsCApSoXWFr2pYtVeeGGlqHUNhiBt2UWkithiLwoWelG2cf0TSkSSTdOojcaatE7+NMnk2YszWYLVnHfc85t3Zp/vB4Yw48vD4yTfec+cP++JzERSPfP6XkBSP4xfKsr4paKMXyrK+KWijF8qyvilonqNPyKWRcSzEbEvInZExD/3uc8wEbEoIjZM7fphRLwSEVf1vVcXEXFORByMiB/2vUsXEXFjRGyZ+rexLSIu6XunE4mIFRHx04jYExG7IuLRiFjQ914n0veZ/zHgELAcuAl4PCK+0O9KJ7QA2AlcCpwM3A08HREretypq8eAX/e9RBcRcSXwIPCvwF8Bfw/8d69LDfd94D3gTGA1g38jt/W60RC9xR8RS4BrgbszcyIzfwk8D/xLXzsNk5n7MvPezPxtZh7NzBeA7cAX+97tRCLiRmAv8PO+d+no28B9mfmrqe/z25n5dt9LDbESeDozD2bmLuBFYDafyHo9838emMzMN4/72mvM8m/Y8SJiOYP/j9/0vcsniYilwH3AN/repYuImA+MA6dHxNaIeGvqJvRJfe82xPeAGyNicUR8FriKwQ+AWavP+D8D/PEjX/sjg5t5s15EjAE/Ap7KzDf63ucE7gc2ZObOvhfpaDkwBlwHXMLgJvSFwLf6XKqDXzA4cf0JeAvYBDzX60ZD9Bn/BLD0I19bCnzYwy7TEhHzgB8wuL/iqz2v84kiYjVwBfBw37tMw4GpPx/JzN9n5vvAd4E1Pe50QlP/Hn4G/BhYApwGnMrgfotZq8/43wQWRMQ5x33tAmbxTWiAiAhgA4Mz1LWZebjnlU7kMmAF8LuI2AV8E7g2Iv6rz6VOJDP3MDhzzqWXmy4DPgc8mpl/zsw/ABuZxT+woMf4M3Mfg5+U90XEkoj4O+CfGJxRZ7PHgfOAf8zMA8MO7tl64K8Z3HReDfwb8BPgH/pcqoONwO0RcUZEnAp8DXih550+0dStk+3ArRGxICJOAdYyuA9r1ur7ob7bgJMYPETy78CtmTlrz/wRcTbwFQYh7YqIiamPm3pe7WNl5v7M3HXsg8GvWgczc3ffuw1xP4OHJd8EtgCvAN/pdaPhvgx8CdgNbAWOAF/vdaMhwot5SDX1feaX1BPjl4oyfqko45eKMn6pKOOXipoV8UfEur53mK65tvNc2xfcubVZET8wZ75hx5lrO8+1fcGdm5ot8UuaYc2e4RcRc+6pg2NjY52PPXr0KPPmdfvZOTk5+WlXGpnMZPCapG5OOqn/l88fPnx4Wn8nBw60eanFwoULOx87OTnJ/PnzOx9/8ODBT7PSUJk59C97Vl9j7ONM5xs7XWeccUaTuXv27GkyF5hW0NOxatWqJnNb2rx5c5O5K1eubDIX4I03Rn8piCNHjnQ6zpv9UlHGLxVl/FJRxi8VZfxSUZ3in2vvrCNpuK4P9R3/zjqrgZ9ExGuz+ZJbkk5s6Jl/Lr6zjqThutzsn/PvrCPpL3W52d/5nXWmXtE0Z17YIFXWJf7O76yTmesZXCt+Tj63X6qky83+OfnOOpJObGj8c/iddSSdQNcn+cypd9aRNFynx/kz8wPg6sa7SJpBPr1XKsr4paKMXyrK+KWiml7Dr+sFLqej5cUwTznllCZzjx
492mQuwAcffNBk7t69e5vM3bZtW5O5ML0LsE7H5Zdf3mQuwNatW0c+s2sjnvmlooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyqq2aW7lyxZwgUXXDDyuRMTEyOfeczrr7/eZO6dd97ZZC7Anj17msx98cUXm8y9+eabm8wF2LFjR5O5V1/d7m0qn3rqqZHPPHToUKfjPPNLRRm/VJTxS0UZv1SU8UtFGb9UlPFLRQ2NPyIWRcSGiNgRER9GxCsRcdVMLCepnS5n/gXATuBS4GTgbuDpiFjRbi1JrQ19hl9m7gPuPe5LL0TEduCLwG/brCWptWn/zh8Ry4HPA78Z/TqSZsq0ntsfEWPAj4CnMvONj/nv64B1AAsXLhzJgpLa6Hzmj4h5wA+AQ8BXP+6YzFyfmeOZOT42NjaiFSW10OnMHxEBbACWA2sy83DTrSQ11/Vm/+PAecAVmXmg4T6SZkiXx/nPBr4CrAZ2RcTE1MdNzbeT1EyXh/p2ADEDu0iaQT69VyrK+KWijF8qyviloppdvTczO19FdDoWLGi2Mk8++WSTubfcckuTuQCDp2CM3tGjR5vMXbVqVZO5MLhidAvbt29vMhfghhtuGPnMZ599ttNxnvmlooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyoqMrPJ4LGxsVy2bNnI5y5evHjkM485+eSTm8zdunVrk7kA+/btazK31ff5oosuajIX4L333msy9+KLL24yF+Cuu+4a+cxrrrmGzZs3D72mu2d+qSjjl4oyfqko45eKMn6pKOOXijJ+qahpxR8R50TEwYj4YauFJM2M6Z75HwN+3WIRSTOrc/wRcSOwF/h5u3UkzZRO8UfEUuA+4Btt15E0UxZ0PO5+YENm7oz45KcMR8Q6YB3AvHnelyjNZkPjj4jVwBXAhcOOzcz1wHoYvLDn/7ydpGa6nPkvA1YAv5s6638GmB8Rf5OZf9tuNUktdYl/PfAfx33+TQY/DG5tsZCkmTE0/szcD+w/9nlETAAHM3N3y8UktdX1Dr//lZn3NthD0gzzLnmpKOOXijJ+qSjjl4qa9h1+XS1dupQrr7xy5HO3bds28pnHTExMNJn7zjvvNJkLsHbt2iZzn3vuuSZz16xZ02QuwBNPPNFk7kMPPdRkLsADDzww8pm7d3d7IM4zv1SU8UtFGb9UlPFLRRm/VJTxS0UZv1SU8UtFGb9UlPFLRRm/VJTxS0UZv1SU8UtFGb9UlPFLRRm/VJTxS0UZv1SU8UtFGb9UVGS2eSftRYsW5VlnndVi7shnHrNkyZImc1999dUmcwHOP//8JnMvvHDoO7J/Khs3bmwyF9r9/Z177rlN5gK8/PLLTeZmZgw7xjO/VJTxS0UZv1SU8UtFGb9UlPFLRRm/VFTn+CPixojYEhH7ImJbRFzScjFJbXV6i+6IuBJ4ELgB+E/gzJZLSWqvU/zAt4H7MvNXU5+/3WgfSTNk6M3+iJgPjAOnR8TWiHgrIh6NiJParyeplS6/8y8HxoDrgEuA1cCFwLc+emBErIuITRGxaXJycqSLShqtLvEfmPrzkcz8fWa+D3wXWPPRAzNzfWaOZ+b4/PnzR7mnpBEbGn9m7gHeAtq8/E9SL7o+1LcRuD0izoiIU4GvAS+0W0tSa13v7b8fOA14EzgIPA18p9VSktrrFH9mHgZum/qQ9P+AT++VijJ+qSjjl4oyfqko45eK6vpQ37RlJocPHx753EOHDo185jHnnXdek7l79+5tMhdg8eLFTeY+88wzTeYuXLiwyVyAffv2NZm7bdu2JnNh0MmojY+PdzrOM79UlPFLRRm/VJTxS0UZv1SU8UtFGb9UlPFLRRm/VJTxS0UZv1SU8UtFGb
9UlPFLRRm/VJTxS0UZv1SU8UtFGb9UlPFLRRm/VFSzq/cuWrSIlStXjnzuPffcM/KZx6xfv77J3EceeaTJXIDrr7++ydznn3++ydwHH3ywyVyALVu2NJm7c+fOJnMBHn744ZHPfPfddzsd55lfKsr4paKMXyrK+KWijF8qyvilooxfKqpT/BGxIiJ+GhF7ImJXRDwaEc2eIyCpva5n/u8D7wFnAquBS4HbWi0lqb2u8a8Ens7Mg5m5C3gR+EK7tSS11jX+7wE3RsTiiPgscBWDHwCS5qiu8f+CwZn+T8BbwCbguY8eFBHrImJTRGw6fPjw6LaUNHJD44+IecDPgB8DS4DTgFOBv3iFRmauz8zxzBwfGxsb9a6SRqjLmX8Z8Dng0cz8c2b+AdgIrGm6maSmhsafme8D24FbI2JBRJwCrAVea72cpHa6/s7/ZeBLwG5gK3AE+HqrpSS11+mJOpn5KnBZ21UkzSSf3isVZfxSUcYvFWX8UlHGLxXV7GW5k5OTTExMjHzu7bffPvKZx+zfv7/J3NNPP73JXICXXnqpydzrrruuydyWl8E+cuRIk7lr165tMhfgjjvuaDZ7GM/8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRkZltBkfsBnZ0PPw04P0mi7Qz13aea/uCO39aZ2fm0EtGN4t/OiJiU2aO973HdMy1nefavuDOrXmzXyrK+KWiZkv86/te4FOYazvPtX3BnZuaFb/zS5p5s+XML2mGGb9UlPFLRRm/VJTxS0X9D1S4+Ko7DI8/AAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "np.fill_diagonal(norm_conf_mx, 0)\n", "plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 8, 9 열이 상당히 밝음 -> 많은 이미지가 8과 9로 잘못 분류되었음\n", "- 8, 9 행도 밝음 -> 숫자 8과 9가 다른 숫자들과 혼돈\n", "- 오차 행렬을 분석하면 분류기의 성능 향상 방안에 대한 통찰을 얻을 수 있음
\n", "이 그래프를 살펴보면 3과 5가 서로 혼돈되고 8과 9를 더 잘 분류할 수 있도록 개선할 필요가 있어 보임\n", " - 이 숫자들에 대한 훈련 데이터를 더 모음\n", " - 분류기에 도움 될 만한 특성을 더 찾아봄(동심원의 수를 세는 알고리즘 - 8은 두 개, 6은 하나, 5는 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "개개의 에러를 분석하기 위해 3과 5의 샘플을 그려보겠습니다." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda3/envs/mlbook/lib/python3.5/site-packages/matplotlib/font_manager.py:1320: UserWarning: findfont: Font family ['NanumBarunGothic'] not found. Falling back to DejaVu Sans\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cl_a, cl_b = 3, 5\n", "X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)] #3으로 정확히 분류\n", "X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)] #5로 잘못 분류\n", "X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)] #3으로 잘못 분류\n", "X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)] #5로 정확히 분류\n", "\n", "plt.figure(figsize=(8,8))\n", "plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n", "plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n", "plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n", "plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 대부분의 잘못 분류된 이미지는 에러인 것 같고 그 원인은 선형 모델인 SGDClassifier를 사용했기 때문
\n", "- 선형 분류기는 클래스마다 픽셀에 가중치를 할당하고 새로운 이미지에 대해 단순히 픽셀 강도의 가중치 합을 클래스의 점수로 계산
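\n", "- 3과 5는 위쪽 선과 아래쪽 호를 잇는 작은 직선의 위치 정도만 다르기 때문에(즉, 몇 개의 픽셀만 다름) 이 모델은 두 숫자를 쉽게 혼동함. 또한 가중치가 픽셀 위치마다 고정되어 있으므로, 이 분류기는 이미지의 위치나 회전 방향에 매우 민감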
\n", "- 3과 5의 에러를 줄이는 한 가지 방법은 이미지를 중앙에 위치시키고 회전되어 있지 않도록 전처리 하는 것" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.6 다중 레이블 분류" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **다중 레이블 분류**multilabel classification : 여러 개의 이진 레이블을 출력하는 분류 시스템
\n", "얼굴 인식 분류기를 예로 들면, 같은 사진에 여러 사람이 등장한다면 인식된 사람마다 레이블을 하나씩 할당해야 함(즉, '앨리스 있음, 밥 없음, 찰리 있음')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[False, True],\n", " [False, False],\n", " [False, False],\n", " ...,\n", " [False, False],\n", " [False, False],\n", " [ True, True]])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "y_train_large = (y_train >= 7) #숫자가 7 이상인지\n", "y_train_odd = (y_train % 2 == 1) #홀수인지\n", "y_multilabel = np.c_[y_train_large, y_train_odd] #두 개의 1차원 배열을 칼럼으로 세로로 붙여서 2차원 배열 만들기\n", "y_multilabel #다중 타깃 레이블이 담긴 배열" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_jobs=None, n_neighbors=5, p=2,\n", " weights='uniform')" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "knn_clf = KNeighborsClassifier()\n", "knn_clf.fit(X_train, y_multilabel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "KNeighborsClassifier는 다중 레이블 분류를 지원하지만 모든 분류기가 그런 것은 아님" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[False, True]])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "knn_clf.predict([some_digit]) #숫자 5 예측" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다음 코드는 다중 레이블 분류기를 평가하기 위해 모든 레이블에 대한 F1 점수의 평균을 계산
\n", "average=\"macro\" 옵션은 레이블마다 F1 점수를 따로 계산한 뒤 단순 평균을 냄(모든 클래스의 TP, FP, FN 총합으로 F1 점수를 한 번에 계산하는 것은 average=\"micro\" 옵션)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.\n", "[Parallel(n_jobs=-1)]: Done   3 out of   3 | elapsed: 21.8min finished\n" ] }, { "data": { "text/plain": [ "0.97709078477525" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3, verbose=3, n_jobs=-1)\n", "f1_score(y_multilabel, y_train_knn_pred, average=\"macro\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 코드는 모든 레이블의 가중치가 같다고 가정한 것
\n", "앨리스 사진이 밥이나 찰리 사진보다 많다면 앨리스 사진에 대한 분류기의 점수에 더 높은 가중치를 둘 것. 간단한 방법은 레이블에 클래스의 **지지도**support(즉, 타깃 레이블에 속한 샘플 수)를 가중치로 줌. 이렇게 하려면 이전 코드에서 average=\"weighted\"로 설정" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.7 다중 출력 분류" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **다중 출력 다중 클래스 분류**multioutput-multiclass classification(또는 **다중 출력 분류**multioutput classification) : 다중 레이블 분류에서 한 레이블이 다중 클래스가 될 수 있도록 일반화한 것(즉, 값을 두 개 이상 가질 수 있음)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이미지에서 노이즈를 제거하는 시스템은 분류기의 출력이 다중 레이블(픽설당 한 레이블)이고 각 레이블은 여러 개의 값을 가짐(0부터 255까지 픽셀 강도). 그러므로 이는 다중 출력 분류 시스템임" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "#넘파이의 randint() 함수를 사용하여 픽셀 강도에 노이즈를 추가\n", "noise = np.random.randint(0, 100, (len(X_train), 784)) #파라미터 : 0~99까지의 랜덤 숫자, 행렬 사이즈\n", "X_train_mod = X_train + noise\n", "noise = np.random.randint(0, 100, (len(X_test), 784))\n", "X_test_mod = X_test + noise\n", "\n", "#타깃 이미지는 원본 이미지\n", "y_train_mod = X_train\n", "y_test_mod = X_test" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda3/envs/mlbook/lib/python3.5/site-packages/matplotlib/font_manager.py:1320: UserWarning: findfont: Font family ['NanumBarunGothic'] not found. 
Falling back to DejaVu Sans\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAC+CAYAAAAhkiQIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEk1JREFUeJzt3UlsFdQXx/FDoSOlZZCpLfNUqIC1DgiWSYVE4kLQuHSnW8NOISaKGk1UMDEacWPixhjjtDEhyFhA5qG0UFpa5gKFzlCgUP7Lf/yf3/3zKsOT+76f5S/n9b2+Po4v3nvP7XP79m0DAMQjLdkvAABwb9HYASAyNHYAiAyNHQAiQ2MHgMjQ2AEgMjR2AIgMjR0AIkNjB4DI0NgBIDL9kvGk3d3dbo5BVVWVrJ04caLLTp06JWtPnDjhsscff1zW7t+/32Xl5eWytq2tzWVZWVmy9sqVKy6rr6+Xtf36+bdfjXgoKCiQj1d5nz59ZO3Vq1dd1traKmsLCwtdpt5bM/33aW9vl7Xqd9uzZ4+s7du3r8seeeQRWZuenu6yadOm6Tfi/mNGB+63O362+cYOAJGhsQNAZGjsABCZpPw/9oaGBpddvHhR1o4bN85lY8eOlbVDhw512aFDh2TtqFGjXJaWpv87N2zYMJdt3bpV1vbv399l169fl7VDhgxxmfr/3j09PfLxZ86ccdmYMWNk7V9//eWyuXPnytoNGza4LCMjQ9YOHDjQZTU1NbJW/b/wSZMmydra2lqXtbS0yNo5c+bIHEhVfGMHgMjQ2AEgMjR2AIgMjR0AIkNjB4DI9EnGZda3xZNeu3ZN1qqTnIMGDZK1XV1dLgvt5tixY4fL5s2bJ2srKytdduHCBVk7e/Zsl4V2xagdNOr1Hj58WD5e7aqprq6WtWpnj9qlYmbW3d3tsvPnzydcq07UmundTOrva2aWn5/vsvHjx8ta9f5mZmZy8hSx4uQpAKQaGjsARIbGDgCRobEDQGSSsnh66dIl96RHjhyRtSUlJS47fvy4rB0+fLjLTp48KWufeeYZl4WO7qvFwNAiZV5enssaGxtlrVooHTx4sMvUwqeZHikQGm17+vRplxUXF8taNVpBLSCb6TEOvRmBUFZWJmubm5tdphZqQ0aPHs3iKWLF4ikApBoaOwBEhsYOAJGhsQNAZGjsABCZpOyKqaysdE86efJkWVtRUeGyWbNmyVp1mURubq6szc7OdpnaiWGmL4NQu1/MzAYMGOCypqYmWZuZmekydal36OLszs5Ol4Uus1bH7kPvjdqZE6pVu3A2btwoa9VIAXVxtplZTk6Oy9RnwcystLTUZbm5ueyKQazYFQMAqYbGDgCRobEDQGRo7AAQmaQsnra3t7snDS0QXrx40WWhI+tqhrc6Sm+m578XFBQk/HNDIxDUsffp06fL2kuXLrlMLYiGFmrV67p165asvXz5sstCC60DBw5M6HWZmbW3t7tMzZk30wuw6vFmZg0NDS4LzXlX4yGys7NZPEWsWDwFgFRDYweAyNDYASAyNHYAiAyNHQAio7cZ3Getra0u681x/v3798tatcPixRdfTPh17du3T+Z9+/Z12YQJE2St2hGyffv2hGvVEf2RI0fKxyuhXSZqpMDEiRNl7ZYtW1ym/g5meixCW1ubrFUXpIR+rhr5ENrxoz47oVEFQCrgGzsARIbGDgCRobEDQGRo7AAQmaSMFNiyZYt70tBxfnUMXc08N9PH3kM326ufG1rAVT9XzRY308f/b9y4IWs3bdrksoULF7qsvr5ePn7cuHEu6+jokLVqTIAa12Bm1tjY6DI189zM7OjRoy5bv369rJ07d67L0tL0dwu1sJueni5r1f
z3559/npECuOd+/vlnmS9dutRlobEjU6dOvduXwUgBAEg1NHYAiAyNHQAiQ2MHgMjQ2AEgMknZFbN9+3b3pKNHj5a16sIFdcTfzGzmzJkuO3PmjKxVuy62bt0qa3uzm0M934YNG2TtgQMHXHb27FmX7d69Wz5eXeAxefJkWVtTU+MytaPFzCwjI8NloR006m8R2rWkLjcZPHiwrFXPp/4OZmaffvqpy4qKitgV8wCsXr1a5urzpvzyyy8yLy4udlloN8nmzZsTfn7V70IXztxt7bBhw2St+syHatXvZuyKAYDUQ2MHgMjQ2AEgMjR2AIhMUhZPGxoa3JOGFueysrJcFpojruZ95+fny9rq6mqXqRngZmYzZsxw2d69e2VtZWWly3744QdZq+bKq9EKofEFly9fdlnod1C1RUVFslYt1oYWRLu6ulxWUlIia9Xs9g8//FDWXrhwwWVqUdfMbPHixaqWxdN7TB2nf+WVV2StWmR8kAuXD7q2vLzcZZ9//rmsVXcuqMzMLCcnR8UsngJAqqGxA0BkaOwAEBkaOwBEhsYOAJHpl4wnVTs3pk2bJmv79fMvUV2sYKZXkEeMGCFrp0yZ4rLQmICmpiaXPffcc7J2586dLqutrZW1zz77rMvUhRhVVVXy8Wp3UGgHjfrd5syZI2vV0f3QkWd19Hv27NmyVl3gcf78eVm7YMEClx07dkzW1tXVuSz0ecI/p47/92ZXXVlZ2V09f2ikgNpRoj6Xvf25aqfLw4Jv7AAQGRo7AESGxg4AkaGxA0BkkrJ4qmaWh2asqwXRCRMmyFq1QHj16lVZ29HR4bKWlhZZqxYDQwu4P/30k8tCs+bV0Xs137q7u1s+vrW11WV5eXmyVv1uocXiSZMmuSy0GKlGPoTmz6tF2Vu3bslaNWJCzXM3M8vNzZU5/hm1WcDMrKKiwmWhI/YrV6502fvvv393LwwJ4xs7AESGxg4AkaGxA0BkaOwAEBkaOwBEJim7Ypqbm10WOgqvLldQOzFCtaEdE21tbS6bP3++rN23b5/LVqxYIWsPHjzostDFE4sWLXLZoUOHXBa6PEONRVA7jsz0hRj19fWyVo1LCB0dV7tXRo0aJWvVex4aVaBqQ5emhC4pwJ0dOXLEZUuXLpW1J0+edFnovVfjKtRzhY7z4+7wjR0AIkNjB4DI0NgBIDI0dgCITJ/ezFO+VzZt2uSeNHT0f8aMGS4LLbgdPXrUZeq2ezOzxx57LOHad99912Xnzp2TtWrmeE9Pj6wNLYr+r++++07mRUVFLgstiKrZ7aGFVnV0P7RweeXKFZepRV0zvZCdnp4ua3fs2OGyQYMGydrLly+7bN68eXe8yf0+efD/oBIQGhPw1FNPuUwtkprp8QGh/qFq1WiN3bt3y8ezIP5/3fGzzTd2AIgMjR0AIkNjB4DI0NgBIDJJOXmqFkofffRRWXvq1CmXhRY51aJdaD65WgxUJ0zNzNavX++yzMxMWdvZ2emy0tJSWavmqd+8edNlL730knz8b7/95jK12GymT56G5toPHTrUZX/++aesVbPbQwu4avZ6aA6/mpcfuiQ7dCE2/kv9OzLTC6W92VDRm9oTJ064LLQRYtOmTS5T8/yh8Y0dACJDYweAyNDYASAyNHYAiAyNHQAik5RdMWrnRui2enUMObTTRR1DV3Ohzcxu3LjhstraWlmrjsIPHz5c1qrX9t5778ladXS/pqbGZaGRAqtWrXLZmjVrZO2YMWMSen4zfURfzbo300e/q6qqZK2a0x4aP3D8+HGXnT59Wta2tLS4rLCwUNbi79TR/5Bly5a57O2330748Vu2bHHZxx9/LGtff/11l/3xxx+ytri4OOHXkCr4xg4AkaGxA0BkaOwAEBkaOwBEJinz2Ht6etyTtre3y9oBAwa4TC18mulLskMzx2fOnOmy0HtRXV3tsu7ublmrRgIUFBTIWvUz1GJkZWWlfPwnn3zism3btsnatWvXuuzVV1+Vte
p3CP191HiIS5cuyVo1qiA0d7uurs5l48ePl7Vpaf77SU5ODvPYH2Jq1EDoc6VGjKjPWkSYxw4AqYbGDgCRobEDQGRo7AAQGRo7AEQmKbti1q1b55502rRpslYdeQ7tdLly5YrL1A4PM73bRh27NzNLT09P+Oe2tbW5TF0aYWY2duxYlw0ePNhloUsuVO3y5ctlrfodvv76a1mrdh319PTIWnXJxaJFi2St+j1Cl5A0NDS4bOLEibK2qalJ1bIr5iG2d+9ely1ZskTWqlEVofEDoV1YDxl2xQBAqqGxA0BkaOwAEBkaOwBEJimLpyYWmNSt5GZ6nrqaF25mdv36dZddvHhR1k6ePNlloQVCNaogNI9dLeQ1NjbKWjUzXP09Qn+jrq4ul/3++++y9ssvv3TZa6+9JmvVAmxoLMKuXbtcFlqgUq930KBBsra+vt5l2dnZsnbSpEkuy8/PZ/E0MmoshpnZm2++6bLVq1fL2rfeeuuevqYkYfEUAFINjR0AIkNjB4DI0NgBIDI0dgCITL9kPGlnZ6fL5s2bJ2s3bNjgstAQfTVqIFSrdliEdtCoXR4VFRWydurUqQk93sxsyJAhLjt16pTL8vPz5eNv3brlskOHDsna3Nxcl4UuLFEXGtTW1sraGTNmuEztaDHT4xbUGAgzPWrg7NmzslYdP1+4cKGsxcPr6NGjMldjR2pqau73y/lX4xs7AESGxg4AkaGxA0BkaOwAEJmkLJ4eOHDAZWVlZbJWHRcP3VY+c+ZMl6m53iF1dXUy79fPv02hxTm1+Nne3i5r1RxpNXP8yJEj8vFffPGFy7Zt2yZrlfLycpmrRVW10GumF5xDi7JPPPGEy0Lvzc6dO1329NNPy9q8vDyZ498v9NlesWKFy3799VdZm5OT47LQnQCpgm/sABAZGjsARIbGDgCRobEDQGRo7AAQmaTsipkyZYrL1G33ZnrHQ+jChYEDB7qsb9++slZdnnH48GFZq3ZjnDx5UtaqPHSBhzqm/+OPP7pszZo18vElJSUuGzNmjKx94403XPbCCy/IWnV0OzMzU9aq49xqzICZHgmQkZEha9V7FqpVO5FGjhwpa1PVBx98IPOVK1fel+dTu10++ugjl4V2uqhRE+qzZmb2zjvvuOzll1++00uMGt/YASAyNHYAiAyNHQAiQ2MHgMgkZfFUHQEOLYyo4/zqtnszs+rqapeFFv2amppctnnzZlm7bt06l4UWCNVR+GPHjslatRioFnVDM+VbW1td9tlnn8ladZxfPZeZ2dWrV12mZr+b6Tnve/bskbVq5n5opID6u4cW2NX7gL9TdxWYmS1btsxlofsD1L/Rb775JuHa27dvJ1RnpnvE999/L2tTfaFU4Rs7AESGxg4AkaGxA0BkaOwAEBkaOwBEpo9aqb7fqqur3ZN2dnbKWrU6XlRUJGvVBRxZWVmyVq3GHzx4UNZ+9dVXCdcWFBS4rL6+XtbevHnTZWosgtpNYma2fPlyl40fP17WqssvQu9NR0eHy9TlKGZmixcvTujxZmaVlZUuU5ejhGrVZ8HM7Pr16y578skn9XaL++/B/4NKwN69e2W+ZMkSl6nLU8wS3+nSm1q1K8fMbNWqVS4rLi6WtSnojp9tvrEDQGRo7AAQGRo7AESGxg4AkUnK4un58+fdkzY2NspalYduq6+pqXFZ6Di+Eho/oI7Nf/vtt7K2sLAwocebmV27ds1lQ4YMcVlogVEtUIXmz1dVVbmsrKxM1qoj+qHfYcCAAS5Ts7TN9CLnrl27Ev656r010+MHRo0axeLpP7R27dq7/hlTp051WXl5+V3/XJgZi6cAkHpo7AAQGRo7AESGxg4AkaGxA0BkkrIrpra21j2p2jFhpof+h46WHz582GXqggkzfSHG7NmzZe3+/ftdVlpaKmvV0f3QJRW7d+922axZsxJ+fFtbm8vUJRlmetSAGmlgpsc7hN5z9doaGhpkrdrhFLqwRP3c0O+mXlthYSG7Yh
ArdsUAQKqhsQNAZGjsABAZGjsARMafxX4A1EKeWsw0M8vIyHCZWsw004ufoSP26jh/V1eXrB0xYoTLQguEakEzLy8v4deg5p73799fPl4d2w4thu/YscNl2dnZslbNtVd/BzOzBQsWuEy9X2Z68VSNLzAzO336tMtCv9v8+fNlDqQqvrEDQGRo7AAQGRo7AESGxg4AkaGxA0Bk/jUXbXR3d8tatWuio6ND1qoLHkI7P9LS/H/TQkfW09PTXTZs2DBZq0YjTJ8+XdZWVFS4LCsrK+HnUr+v2tFipncihbS0tLgs9PcpKSlxWXNzc8I/99y5c7JWjYJQl2+YmW3cuNFlixYtYqQAYsVIAQBINTR2AIgMjR0AIkNjB4DIJGXxFABw//CNHQAiQ2MHgMjQ2AEgMjR2AIgMjR0AIkNjB4DI0NgBIDI0dgCIDI0dACJDYweAyNDYASAyNHYAiAyNHQAiQ2MHgMjQ2AEgMjR2AIgMjR0AIkNjB4DI0NgBIDI0dgCIDI0dACJDYweAyNDYASAyNHYAiAyNHQAiQ2MHgMj8B8OXwRosN/S9AAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#테스트 세트에서 이미지를 하나 선택\n", "some_index = 5500\n", "plt.subplot(121); plot_digit(X_test_mod[some_index])\n", "plt.subplot(122); plot_digit(y_test_mod[some_index])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda3/envs/mlbook/lib/python3.5/site-packages/matplotlib/font_manager.py:1320: UserWarning: findfont: Font family ['NanumBarunGothic'] not found. 
Falling back to DejaVu Sans\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEYCAYAAACDezmxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAABepJREFUeJzt3SGLlFscwOGZy9q22FxEBYNJi9isJoOwwWRQEASrsN/AZhCbYNFmMtpMJsMmnawGNSqKlgl7v8GcuTq/nXHv89Tz531P2P1xwuGd6cHBwQSg9M+6NwAcfUID5IQGyAkNkBMaICc0QE5ogJzQADmhAXJba3qv68hwNEyXGXKiAXJCA+SEBsgJDZATGiAnNEBOaICc0AA5oQFyQgPkhAbICQ2QExogJzRATmiAnNAAOaEBckID5IQGyAkNkBMaICc0QE5ogJzQADmhAXJCA+SEBsgJDZATGiAnNEBOaICc0AA5oQFyQgPkhAbICQ2QExogJzRATmiAnNAAOaEBckID5IQGyAkNkBMaICc0QE5ogJzQADmhAXJCA+SEBsgJDZATGiAnNEBOaICc0AA5oQFyQgPkhAbICQ2Q21r3BoD/Zj6fD2eOHTt2CDtZnhMNkBMaICc0QE5ogJzQADmhAXJCA+Tco+G3fPz4ceH6bDYbPuP06dPDmTNnzgxnfv78uXB9Z2dn+IwHDx4MZ/b39xeuP3/+fPiM3d3d4cyvX78Wrt+8eXP4jHv37g1nvnz5MpxZFScaICc0QE5ogJzQADmhAXJCA+SEBsgJDZCbHhwcrOO9a3kpk8l0Ol33FhhY0//k71rqD8qJBsgJDZATGiAnNEBOaICc0AA5oQFyQgPkfGHviHn8+PG6t7C0vb294cyFCxcOYSeTyeXLl4czZ8+ePYSdHE1ONEBOaICc0AA5oQFyQgPkhAbICQ2Q8+Grv8jolxInk8nk0qVLf/ye+Xw+nNnacgWLyWTiw1fAphAaICc0QE5ogJzQADmhAXJCA+SEBsi5dbUh3r59O5xZxWW8p0+fDmdcxmPVnGiAnNAAOaEBckID5IQGyAkNkBMaICc0QM4X9jbEkydPhjN37tw5hJ1MJmv6m+Dv5At7wGYQGiAnNEBOaICc0AA5oQFyQgPk3KPZENPpUtcRDsXDhw+HM1evXl24fu7cuVVth83mHg2wGYQGyAkNkBMaICc0QE5ogJzQADmhAXIu7P1Fvn37Npx5+fLlwvUbN26sZC+z2Wzh+u3bt4fPuH79+nBmmY99bW9vD2fIuLAHbAahAXJCA+SEBsgJDZATGiAnNEBOaICcC3sknj17Npy5devWcGZ3d3c48+LFi2W2RMOFPWAzCA2QExogJzRATmiAnNAAOaEBckID5FzYI/Hu3bvhzLVr14Yz79+/H87s7+8vXL948eLwGfw2F/aAzSA0QE5ogJzQADmhAXJCA+SEBsi5R8PafP78eThz8uTJP37Ozs7O0nviP3OPBtgMQgPkhAbICQ2QExogJzRATmiAnNAAua11b4D/rzdv3gxnTpw4MZxxIW/zOdEAOaEBckID5IQGyAkNkBMaICc0QE5ogJwLeyS+fv06nLl///5w5u7du6vYDmvmRAPkhAbICQ2QExogJzRATmiAnNAAOb9USWI6XeoHDIfW9PfJ8vxSJbAZhAbICQ2QExogJzRATmiAnNAAOaEBcj58tSHOnz8/nHn06NFw5sOHDwvXT506teyWFnr16tUfP+P169cr2Al/AycaICc0QE5ogJzQADmhAXJCA+SEBsgJDZBzYW9DzGaz4cyVK1cOYSfL2dvbW7j+48eP4TO2t7dXtR02nBMNkBMaICc0QE5ogJzQADmhAXJCA+
Tco9kQy/wi43w+H858//594fqnT5+Gzzh+/PhwZlUf0OL/wYkGyAkNkBMaICc0QE5ogJzQADmhAXJCA+Smy1wUC6zlpcDKTZcZcqIBckID5IQGyAkNkBMaICc0QE5ogJzQADmhAXJCA+SEBsgJDZATGiAnNEBOaICc0AA5oQFyQgPkhAbICQ2QExogJzRATmiAnNAAOaEBcltreu9Sv24HHA1ONEBOaICc0AA5oQFyQgPkhAbICQ2QExogJzRATmiAnNAAOaEBckID5IQGyAkNkBMaICc0QE5ogJzQADmhAXJCA+SEBsgJDZATGiAnNEBOaIDcv05ArrZeQ7ZeAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "knn_clf.fit(X_train_mod, y_train_mod)\n", "clean_digit = knn_clf.predict([X_test_mod[some_index]])\n", "plot_digit(clean_digit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3.8 연습문제" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1 . MNIST 데이터넷으로 분류기를 만들어 테스트 세트에서 97% 정확도를 달성해보세요. 힌트: KNeighborsClassifier가 이 작업에 아주 잘 맞습니다. 좋은 하이퍼파라미터 값만 찾으면 됩니다.(weights와 n_neighbors 하이퍼파라미터로 그리드 탐색을 시도해보세요)." ] }, { "cell_type": "code", "execution_count": 135, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 6 candidates, totalling 30 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.\n", "[Parallel(n_jobs=-1)]: Done 30 out of 30 | elapsed: 623.9min finished\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise-deprecating',\n", " estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_jobs=None, n_neighbors=5, p=2,\n", " weights='uniform'),\n", " fit_params=None, iid='warn', n_jobs=-1,\n", " param_grid=[{'weights': ['uniform', 'distance'], 'n_neighbors': [3, 4, 5]}],\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n", " scoring=None, verbose=3)" ] }, "execution_count": 135, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "#‘uniform’일 때는 np.mean 함수를 사용하여 단순 평균을 계산하고, 
‘distance’일 때는 거리를 고려한 가중치 평균(average)을 계산\n", "param_grid = [{'weights': [\"uniform\", \"distance\"], 'n_neighbors': [3, 4, 5]}]\n", "\n", "knn_clf = KNeighborsClassifier()\n", "grid_search = GridSearchCV(knn_clf, param_grid, cv=5, verbose=3, n_jobs=-1)\n", "grid_search.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 138, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'n_neighbors': 4, 'weights': 'distance'}" ] }, "execution_count": 138, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": "code", "execution_count": 139, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9714" ] }, "execution_count": 139, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import accuracy_score\n", "\n", "y_pred = grid_search.predict(X_test)\n", "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2 . MNIST 이미지를 (왼, 오른, 위, 아래) 어느 방향으로든 한 픽셀 이동시킬 수 있는 함수를 만들어보세요. 그런 다음 훈련 세트에 있는 각 이미지에 대해 네 개의 이동된 복사본(방향마다 한 개씩)을 만들어 훈련 세트에 추가하세요. 마지막으로 이 확장된 데이터셋에서 앞에서 찾은 최선의 모델을 훈련시키고 테스트 세트에서 정확도를 측정해보세요. 모델 성능이 더 높아졌는지 확인해보세요! 인위적으로 훈련 세트를 늘리는 이 기법을 **데이터 증식** 또는 **훈련 세트 확장**training set expansion이라고 합니다." 
] }, { "cell_type": "code", "execution_count": 141, "metadata": {}, "outputs": [], "source": [ "from scipy.ndimage.interpolation import shift\n", "\n", "def shift_image(image, dx, dy):\n", " image = image.reshape((28, 28))\n", " shifted_image = shift(image, [dy, dx], cval=0, mode=\"constant\") #모드가 constant이면 경계 밖의 값이 cval값으로 채워짐\n", " return shifted_image.reshape([-1])\n", "\n", "X_train_augmented = [image for image in X_train] #numpy.ndarray 타입을 list 타입으로 변경\n", "y_train_augmented = [label for label in y_train]\n", "\n", "for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)): #오른, 왼, 아래, 위\n", " for image, label in zip(X_train, y_train):\n", " X_train_augmented.append(shift_image(image, dx, dy)) #이동시킨 이미지를 훈련 세트에 추가\n", " y_train_augmented.append(label)\n", "\n", "X_train_augmented = np.array(X_train_augmented)\n", "y_train_augmented = np.array(y_train_augmented)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "어떤 학습 알고리즘은 훈련 샘플의 순서에 민감해서 많은 비슷한 샘플이 연이어 나타나면 성능이 나빠지기때문에 훈련 세트를 섞어서 모든 교차 검증 폴드가 비슷해지도록 하여 이런 문제를 방지함" ] }, { "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [], "source": [ "shuffle_idx = np.random.permutation(len(X_train_augmented))\n", "X_train_augmented = X_train_augmented[shuffle_idx]\n", "y_train_augmented = y_train_augmented[shuffle_idx]" ] }, { "cell_type": "code", "execution_count": 143, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9763" ] }, "execution_count": 143, "metadata": {}, "output_type": "execute_result" } ], "source": [ "knn_clf = KNeighborsClassifier(**grid_search.best_params_)\n", "knn_clf.fit(X_train_augmented, y_train_augmented)\n", "y_pred = knn_clf.predict(X_test)\n", "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3 . **타이타닉**Titanic 데이터셋에 도전해보세요. 캐글Kaggle에서 시작하면 좋습니다(https://www.kaggle.com/c/titanic)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "승객의 나이, 성별, 승객 등급, 승선 위치 같은 속성을 기반으로 하여 승객의 생존 여부를 예측하는 것이 목표입니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "데이터를 적재합니다:" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")\n", "\n", "import pandas as pd\n", "\n", "def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n", " csv_path = os.path.join(titanic_path, filename)\n", " return pd.read_csv(csv_path)\n", "\n", "train_data = load_titanic_data(\"train.csv\")\n", "test_data = load_titanic_data(\"test.csv\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "훈련 세트에서 맨 위 몇 개의 행을 살펴보겠습니다:" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS
\n", "
" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 NaN S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 NaN S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 NaN S " ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Survived: 타깃. 0은 생존하지 못한 것이고 1은 생존을 의미
\n", "- Pclass: 승객 등급. 1, 2, 3등석
\n", "- SibSp: 함께 탑승한 형제, 배우자의 수
\n", "- Parch: 함께 탑승한 자녀, 부모의 수
\n", "- Cabin: 객실 번호
\n", "- Embarked: 승객이 탑승한 곳. C(Cherbourg), Q(Queenstown), S(Southampton)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "누락된 데이터가 얼마나 되는지 알아보겠습니다:" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 891 entries, 0 to 890\n", "Data columns (total 12 columns):\n", "PassengerId 891 non-null int64\n", "Survived 891 non-null int64\n", "Pclass 891 non-null int64\n", "Name 891 non-null object\n", "Sex 891 non-null object\n", "Age 714 non-null float64\n", "SibSp 891 non-null int64\n", "Parch 891 non-null int64\n", "Ticket 891 non-null object\n", "Fare 891 non-null float64\n", "Cabin 204 non-null object\n", "Embarked 889 non-null object\n", "dtypes: float64(2), int64(5), object(5)\n", "memory usage: 83.6+ KB\n" ] } ], "source": [ "train_data.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Age, Cabin, Embarked 속성의 일부가 null
\n", "- 특히 Cabin은 77%가 null. 이 속성은 무시
\n", "- Age는 19%가 null. null값은 중간값으로 채움
\n", "- Name과 Ticket 속성은 숫자로 변환하기가 까다롭기 때문에 이 두 속성은 무시\n", "- 변환시켜야할 범주형 특성 : Pclass, Sex, Embarked" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 전처리 파이프라인을 만듭니다. DataFrame으로부터 특정 속성만 선택하기 위해 이전 장에서 만든 DataframeSelector를 재사용하겠습니다:" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "from sklearn.base import BaseEstimator, TransformerMixin\n", "\n", "# 사이킷런이 DataFrame을 바로 사용하지 못하므로\n", "# 수치형이나 범주형 컬럼을 선택하는 클래스를 만듭니다.\n", "class DataFrameSelector(BaseEstimator, TransformerMixin):\n", " def __init__(self, attribute_names):\n", " self.attribute_names = attribute_names\n", " def fit(self, X, y=None):\n", " return self\n", " def transform(self, X):\n", " return X[self.attribute_names]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "숫자 특성을 위한 파이프라인을 만듭니다:" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[22. , 1. , 0. , 7.25 ],\n", " [38. , 1. , 0. , 71.2833],\n", " [26. , 0. , 0. , 7.925 ],\n", " ...,\n", " [28. , 1. , 2. , 23.45 ],\n", " [26. , 0. , 0. , 30. ],\n", " [32. , 0. , 0. 
, 7.75 ]])" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.impute import SimpleImputer\n", "\n", "imputer = SimpleImputer(strategy=\"median\")\n", "\n", "num_pipeline = Pipeline([\n", " (\"select_numeric\", DataFrameSelector([\"Age\", \"SibSp\", \"Parch\", \"Fare\"])),\n", " (\"imputer\", SimpleImputer(strategy=\"median\")),\n", " ])\n", "\n", "num_pipeline.fit_transform(train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 범주형 특성을 위한 파이프라인을 만듭니다:" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 1., ..., 0., 0., 1.],\n", " [1., 0., 0., ..., 1., 0., 0.],\n", " [0., 0., 1., ..., 0., 0., 1.],\n", " ...,\n", " [0., 0., 1., ..., 0., 0., 1.],\n", " [1., 0., 0., ..., 1., 0., 0.],\n", " [0., 0., 1., ..., 0., 1., 0.]])" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.preprocessing import OneHotEncoder\n", "\n", "\n", "cat_pipeline = Pipeline([\n", " (\"select_cat\", DataFrameSelector([\"Pclass\", \"Sex\", \"Embarked\"])),\n", " (\"imputer\", SimpleImputer(strategy='most_frequent')),\n", " (\"cat_encoder\", OneHotEncoder(sparse=False)),\n", " ])\n", "\n", "cat_pipeline.fit_transform(train_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "마지막으로 숫자와 범주형 파이프라인을 연결합니다:" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import FeatureUnion\n", "preprocess_pipeline = FeatureUnion(transformer_list=[\n", " (\"num_pipeline\", num_pipeline),\n", " (\"cat_pipeline\", cat_pipeline),\n", " ])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 원본 데이터를 받아 머신러닝 모델에 주입할 숫자 입력 특성을 출력하는 전처리 파이프라인을 만들었습니다." 
] }, { "cell_type": "code", "execution_count": 144, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[22., 1., 0., ..., 0., 0., 1.],\n", " [38., 1., 0., ..., 1., 0., 0.],\n", " [26., 0., 0., ..., 0., 0., 1.],\n", " ...,\n", " [28., 1., 2., ..., 0., 0., 1.],\n", " [26., 0., 0., ..., 1., 0., 0.],\n", " [32., 0., 0., ..., 0., 1., 0.]])" ] }, "execution_count": 144, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train_titanic = preprocess_pipeline.fit_transform(train_data)\n", "X_train_titanic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "레이블을 가져옵니다:" ] }, { "cell_type": "code", "execution_count": 145, "metadata": {}, "outputs": [], "source": [ "y_train_titanic = train_data[\"Survived\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 분류기를 훈련시킬 차례입니다. RandomForestClassifier를 적용해 보겠습니다:\n", "최적의 하이퍼파라미터를 찾기위해 랜덤탐색을 시행" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "forest_clf = RandomForestClassifier(random_state=42)" ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomizedSearchCV(cv=5, error_score='raise-deprecating',\n", " estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None,\n", " oob_score=False, random_state=42, verbose=0, warm_start=False),\n", " fit_params=None, iid='warn', n_iter=50, n_jobs=-1,\n", " param_distributions={'n_estimators': , 'bootstrap': [True, False]},\n", " pre_dispatch='2*n_jobs', random_state=42, refit=True,\n", " return_train_score='warn', scoring='neg_mean_squared_error',\n", " verbose=0)" ] }, "execution_count": 117, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import RandomizedSearchCV\n", "from scipy.stats import randint\n", "\n", "param_distribs = {\n", " 'n_estimators': randint(low=1, high=100),\n", " 'bootstrap': [True, False],\n", " }\n", "\n", "rnd_search = RandomizedSearchCV(forest_clf, param_distributions=param_distribs,\n", " n_iter=50, cv=5, scoring='neg_mean_squared_error', \n", " random_state=42, n_jobs=-1)\n", "rnd_search.fit(X_train_titanic, y_train_titanic)" ] }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.44064028507448966 {'n_estimators': 52, 'bootstrap': True}\n", "0.4519567135595372 {'n_estimators': 15, 'bootstrap': True}\n", "0.4393649125440716 {'n_estimators': 72, 'bootstrap': True}\n", "0.4431800195652587 {'n_estimators': 21, 'bootstrap': True}\n", "0.4393649125440716 {'n_estimators': 83, 'bootstrap': True}\n", "0.4393649125440716 {'n_estimators': 75, 'bootstrap': True}\n", "0.4393649125440716 {'n_estimators': 88, 'bootstrap': True}\n", "0.4494665749754947 {'n_estimators': 24, 'bootstrap': True}\n", "0.4519567135595372 {'n_estimators': 22, 'bootstrap': True}\n", "0.47614237371283535 {'n_estimators': 2, 'bootstrap': True}\n", "0.45071336398632533 {'n_estimators': 30, 'bootstrap': False}\n", "0.45071336398632533 {'n_estimators': 2, 'bootstrap': False}\n", "0.4519567135595372 {'n_estimators': 60, 'bootstrap': False}\n", "0.4419119768530779 {'n_estimators': 33, 'bootstrap': True}\n", "0.45071336398632533 {'n_estimators': 58, 'bootstrap': False}\n", "0.4519567135595372 {'n_estimators': 89, 'bootstrap': False}\n", "0.44064028507448966 {'n_estimators': 91, 'bootstrap': True}\n", "0.4419119768530779 {'n_estimators': 42, 'bootstrap': True}\n", "0.4519567135595372 {'n_estimators': 92, 'bootstrap': False}\n", "0.4519567135595372 {'n_estimators': 80, 'bootstrap': False}\n", "0.44064028507448966 {'n_estimators': 62, 'bootstrap': True}\n", 
"0.4531966520035264 {'n_estimators': 47, 'bootstrap': False}\n", "0.4531966520035264 {'n_estimators': 51, 'bootstrap': False}\n", "0.4544332072404845 {'n_estimators': 55, 'bootstrap': False}\n", "0.45071336398632533 {'n_estimators': 64, 'bootstrap': False}\n", "0.4702125984770659 {'n_estimators': 3, 'bootstrap': True}\n", "0.4431800195652587 {'n_estimators': 51, 'bootstrap': True}\n", "0.4431800195652587 {'n_estimators': 21, 'bootstrap': True}\n", "0.4431800195652587 {'n_estimators': 39, 'bootstrap': True}\n", "0.4469625634310624 {'n_estimators': 4, 'bootstrap': False}\n", "0.4393649125440716 {'n_estimators': 60, 'bootstrap': True}\n", "0.4581228472908512 {'n_estimators': 9, 'bootstrap': False}\n", "0.45566640681373577 {'n_estimators': 53, 'bootstrap': False}\n", "0.4519567135595372 {'n_estimators': 84, 'bootstrap': False}\n", "0.4519567135595372 {'n_estimators': 60, 'bootstrap': False}\n", "0.4419119768530779 {'n_estimators': 44, 'bootstrap': True}\n", "0.4457052822810143 {'n_estimators': 8, 'bootstrap': True}\n", "0.4457052822810143 {'n_estimators': 35, 'bootstrap': True}\n", "0.4519567135595372 {'n_estimators': 81, 'bootstrap': False}\n", "0.4544332072404845 {'n_estimators': 50, 'bootstrap': False}\n", "0.4469625634310624 {'n_estimators': 4, 'bootstrap': False}\n", "0.4494665749754947 {'n_estimators': 6, 'bootstrap': False}\n", "0.4469625634310624 {'n_estimators': 4, 'bootstrap': False}\n", "0.45071336398632533 {'n_estimators': 93, 'bootstrap': False}\n", "0.44821631782492505 {'n_estimators': 18, 'bootstrap': True}\n", "0.4519567135595372 {'n_estimators': 44, 'bootstrap': False}\n", "0.4494665749754947 {'n_estimators': 74, 'bootstrap': False}\n", "0.45071336398632533 {'n_estimators': 14, 'bootstrap': False}\n", "0.4431800195652587 {'n_estimators': 48, 'bootstrap': True}\n", "0.4393649125440716 {'n_estimators': 72, 'bootstrap': True}\n" ] } ], "source": [ "cvres = rnd_search.cv_results_\n", "for mean_score, params in zip(cvres[\"mean_test_score\"], 
cvres[\"params\"]):\n", " print(np.sqrt(-mean_score), params)" ] }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'bootstrap': True, 'n_estimators': 72}" ] }, "execution_count": 119, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnd_search.best_params_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이를 사용해서 테스트 세트에 대한 예측을 만듭니다:" ] }, { "cell_type": "code", "execution_count": 146, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1,\n", " 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,\n", " 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1,\n", " 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0,\n", " 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n", " 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1,\n", " 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,\n", " 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n", " 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n", " 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n", " 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n", " 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,\n", " 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0,\n", " 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0,\n", " 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,\n", " 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1])" ] }, "execution_count": 146, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# Reuse the best hyperparameters found by the randomized search instead of hard-coding them\n", "forest_clf = RandomForestClassifier(random_state=42, **rnd_search.best_params_)\n", "forest_clf.fit(X_train_titanic, y_train_titanic)\n", "\n", "X_test_titanic = preprocess_pipeline.transform(test_data)\n", "# These are predicted labels (y), not features, hence the y_ prefix\n", "y_test_pred = forest_clf.predict(X_test_titanic)\n", "y_test_pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 예측 결과를 CSV 파일로 만들어 캐글에 업로드하고 평가를 받아볼 수 있습니다. 하지만 교차 검증으로 모델이 얼마나 좋은지 먼저 평가하겠습니다." ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.817174838270344" ] }, "execution_count": 123, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import cross_val_score\n", "\n", "# Cross-validate on the Titanic features/labels; X_train/y_train hold the MNIST data in this notebook\n", "forest_scores = cross_val_score(forest_clf, X_train_titanic, y_train_titanic, cv=10)\n", "forest_scores.mean()" ] }, { "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [], "source": [ "# Submission table for Kaggle: one row per test passenger with its predicted Survived label\n", "df = pd.DataFrame({\"PassengerId\": test_data[\"PassengerId\"], \"Survived\": y_test_pred})\n", "df.to_csv(\"titanic_predict.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 170, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvived
08920
18930
28940
38951
48961
58970
68980
78990
89001
99010
109020
119030
129041
139050
149061
159071
169080
179091
189100
199111
209121
219131
229141
239151
249161
259170
269181
279191
289201
299210
.........
38812800
38912810
39012820
39112831
39212840
39312850
39412860
39512871
39612880
39712891
39812900
39912910
40012921
40112930
40212941
40312950
40412961
40512970
40612980
40712990
40813001
40913011
41013021
41113031
41213040
41313050
41413061
41513070
41613080
41713091
\n", "

418 rows × 2 columns

\n", "
" ], "text/plain": [ " PassengerId Survived\n", "0 892 0\n", "1 893 0\n", "2 894 0\n", "3 895 1\n", "4 896 1\n", "5 897 0\n", "6 898 0\n", "7 899 0\n", "8 900 1\n", "9 901 0\n", "10 902 0\n", "11 903 0\n", "12 904 1\n", "13 905 0\n", "14 906 1\n", "15 907 1\n", "16 908 0\n", "17 909 1\n", "18 910 0\n", "19 911 1\n", "20 912 1\n", "21 913 1\n", "22 914 1\n", "23 915 1\n", "24 916 1\n", "25 917 0\n", "26 918 1\n", "27 919 1\n", "28 920 1\n", "29 921 0\n", ".. ... ...\n", "388 1280 0\n", "389 1281 0\n", "390 1282 0\n", "391 1283 1\n", "392 1284 0\n", "393 1285 0\n", "394 1286 0\n", "395 1287 1\n", "396 1288 0\n", "397 1289 1\n", "398 1290 0\n", "399 1291 0\n", "400 1292 1\n", "401 1293 0\n", "402 1294 1\n", "403 1295 0\n", "404 1296 1\n", "405 1297 0\n", "406 1298 0\n", "407 1299 0\n", "408 1300 1\n", "409 1301 1\n", "410 1302 1\n", "411 1303 1\n", "412 1304 0\n", "413 1305 0\n", "414 1306 1\n", "415 1307 0\n", "416 1308 0\n", "417 1309 1\n", "\n", "[418 rows x 2 columns]" ] }, "execution_count": 170, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 2 }