{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "LlS4gyYqJVpn" }, "source": [ "# 로지스틱 회귀" ] }, { "cell_type": "markdown", "metadata": { "id": "m9MHDgx2JVpv" }, "source": [ "\n", " \n", "
\n", " 구글 코랩에서 실행하기\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "tQBDVUjBbIsL" }, "source": [ "## 럭키백의 확률" ] }, { "cell_type": "markdown", "metadata": { "id": "ILi_LPl9JVpw" }, "source": [ "### 데이터 준비하기" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "Mba6QeEmLn3r", "outputId": "752ed0d7-c1cf-4dea-9baa-7953bfa0fbf3" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Species Weight Length Diagonal Height Width\n", "0 Bream 242.0 25.4 30.0 11.5200 4.0200\n", "1 Bream 290.0 26.3 31.2 12.4800 4.3056\n", "2 Bream 340.0 26.5 31.1 12.3778 4.6961\n", "3 Bream 363.0 29.0 33.5 12.7300 4.4555\n", "4 Bream 430.0 29.0 34.0 12.4440 5.1340" ], "text/html": [ "\n", "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SpeciesWeightLengthDiagonalHeightWidth
0Bream242.025.430.011.52004.0200
1Bream290.026.331.212.48004.3056
2Bream340.026.531.112.37784.6961
3Bream363.029.033.512.73004.4555
4Bream430.029.034.012.44405.1340
\n", "
\n", " \n", "\n", "\n", "\n", "
\n", " \n", "
\n", "\n", "\n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "
\n", "
\n" ] }, "metadata": {}, "execution_count": 1 } ], "source": [ "import pandas as pd\n", "\n", "# Read the fish measurements from the hosted CSV file.\n", "fish = pd.read_csv('https://bit.ly/fish_csv_data')\n", "\n", "# Preview the first five rows as the rich output of this cell.\n", "fish.head(5)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UWJWlRCHVWUg", "outputId": "3ce7ab01-0ab1-45eb-ddce-c0a7378f5220" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Bream' 'Roach' 'Whitefish' 'Parkki' 'Perch' 'Pike' 'Smelt']\n" ] } ], "source": [ "# Distinct values found in the Species column (these are the class labels).\n", "print(fish['Species'].unique())" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "8JjMXc9wVE7C" }, "outputs": [], "source": [ "# The five numeric measurement columns become the model inputs.\n", "feature_columns = ['Weight', 'Length', 'Diagonal', 'Height', 'Width']\n", "fish_input = fish[feature_columns].to_numpy()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1T6C1d5iMzb8", "outputId": "fbe78230-14cd-4327-da8f-377e72774930", "scrolled": true }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[242. 25.4 30. 11.52 4.02 ]\n", " [290. 26.3 31.2 12.48 4.3056]\n", " [340. 26.5 31.1 12.3778 4.6961]\n", " [363. 29. 33.5 12.73 4.4555]\n", " [430. 29. 34. 
12.444 5.134 ]]\n" ] } ], "source": [ "# First five rows of the feature matrix.\n", "print(fish_input[:5])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "aB2oHhojTfWE" }, "outputs": [], "source": [ "# The species name of every row is the classification target.\n", "fish_target = fish['Species'].to_numpy()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "dkllezAJW63K" }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# A fixed random_state keeps the train/test partition the same on every run.\n", "train_input, test_input, train_target, test_target = train_test_split(\n", "    fish_input, fish_target, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "H0ujq0BjXpfp" }, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "# Fit the scaler on the training split only, then apply that same\n", "# transformation to both the training and the test split.\n", "ss = StandardScaler().fit(train_input)\n", "train_scaled = ss.transform(train_input)\n", "test_scaled = ss.transform(test_input)" ] }, { "cell_type": "markdown", "metadata": { "id": "oAxk-V5kQcgc" }, "source": [ "### k-최근접 이웃 분류기의 확률 예측" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BYWTNPOdXOfr", "outputId": "4cd37f3b-5241-447a-a9f2-5ccfa71218b0" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.8907563025210085\n", "0.85\n" ] } ], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "# Build the 3-nearest-neighbors classifier and train it in one expression.\n", "kn = KNeighborsClassifier(n_neighbors=3).fit(train_scaled, train_target)\n", "\n", "# Score on the training split first, then on the test split.\n", "print(kn.score(train_scaled, train_target))\n", "print(kn.score(test_scaled, test_target))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a879-O42RhFO", "outputId": "d7fffbf6-fb8d-420e-8000-56ad09c65c91" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Bream' 'Parkki' 'Perch' 'Pike' 'Roach' 'Smelt' 'Whitefish']\n" ] } ], "source": [ "# Class labels stored on the fitted classifier.\n", "print(kn.classes_)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": 
"EucmtF8HVOS_", "outputId": "a803d3da-7663-481d-88ef-24ddea427451" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Perch' 'Smelt' 'Pike' 'Perch' 'Perch']\n" ] } ], "source": [ "# Predicted species for the first five test samples.\n", "print(kn.predict(test_scaled[:5]))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OSDr8WSKXbUa", "outputId": "ecc49bd9-8a45-43b9-a9a2-19329f5c6bb8" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[0. 0. 1. 0. 0. 0. 0. ]\n", " [0. 0. 0. 0. 0. 1. 0. ]\n", " [0. 0. 0. 1. 0. 0. 0. ]\n", " [0. 0. 0.6667 0. 0.3333 0. 0. ]\n", " [0. 0. 0.6667 0. 0.3333 0. 0. ]]\n" ] } ], "source": [ "import numpy as np\n", "\n", "# Per-class probabilities for the same five samples, rounded to 4 decimals.\n", "proba = kn.predict_proba(test_scaled[:5])\n", "print(proba.round(decimals=4))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Hk-ywsfKkf7t", "outputId": "51bc04ca-95f0-480f-8558-3fce405c8756" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[['Roach' 'Perch' 'Perch']]\n" ] } ], "source": [ "# Look up the neighbors of the fourth test sample; slicing with 3:4\n", "# (instead of indexing with 3) keeps the input two-dimensional.\n", "distances, indexes = kn.kneighbors(test_scaled[3:4])\n", "print(train_target[indexes])" ] }, { "cell_type": "markdown", "metadata": { "id": "Q9_wuI_0tEqL" }, "source": [ "## 로지스틱 회귀" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 449 }, "id": "8rdDSaZ5uji2", "outputId": "8ea6c0b2-1965-47df-be28-465fdfc493f1" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# Sigmoid function evaluated for z from -5 up to (not including) 5, step 0.1.\n", "z = np.arange(-5, 5, 0.1)\n", "phi = 1 / (1 + np.exp(-z))\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(z, phi)\n", "ax.set_xlabel('z')\n", "ax.set_ylabel('phi')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "1J6LGKpUJbFE" }, "source": [ "### 로지스틱 회귀로 이진 분류 수행하기" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JeR5cA_fIe24", "outputId": "a84389ab-489b-4bac-f99e-1dd86b6a1b1e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['A' 'C']\n" ] } ], "source": [ "# Boolean indexing demo: only the positions marked True are kept.\n", "char_arr = np.array(['A', 'B', 'C', 'D', 'E'])\n", "print(char_arr[[True, False, True, False, False]])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "Khxh-3t5-2Tk" }, "outputs": [], "source": [ "# Boolean mask that is True for training rows labeled Bream or Smelt,\n", "# used to carve out a two-class subset of the scaled training data.\n", "bream_smelt_indexes = np.isin(train_target, ['Bream', 'Smelt'])\n", "train_bream_smelt = train_scaled[bream_smelt_indexes]\n", "target_bream_smelt = train_target[bream_smelt_indexes]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 75 }, "id": "jEzP0aeXANra", "outputId": "0d636b60-eda6-4ead-c2e2-162dc1c97b06" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "LogisticRegression()" ], "text/html": [ "
LogisticRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ] }, "metadata": {}, "execution_count": 16 } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "# The two-class model gets its own name so that the multiclass model fitted\n", "# further down (bound to `lr`) does not overwrite it; the cells of this\n", "# section then give the same results even when re-run out of order.\n", "lr_binary = LogisticRegression()\n", "lr_binary.fit(train_bream_smelt, target_bream_smelt)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VtEWtsB7EIgm", "outputId": "6c374a95-a9f2-4983-faba-beb8aaf4c58e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Bream' 'Smelt' 'Bream' 'Bream' 'Bream']\n" ] } ], "source": [ "# Predicted label for the first five Bream/Smelt training samples.\n", "print(lr_binary.predict(train_bream_smelt[:5]))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3H_qieV-_CTt", "outputId": "f38633e8-1b98-4b17-f86b-95685099e372" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[0.99759855 0.00240145]\n", " [0.02735183 0.97264817]\n", " [0.99486072 0.00513928]\n", " [0.98584202 0.01415798]\n", " [0.99767269 0.00232731]]\n" ] } ], "source": [ "# Two probabilities per sample, one column per class.\n", "print(lr_binary.predict_proba(train_bream_smelt[:5]))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Mm60bpr7EQKU", "outputId": "1f439ee9-e3a5-44ff-a0b3-1462670728fb" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Bream' 'Smelt']\n" ] } ], "source": [ "# Class labels stored on the fitted two-class model.\n", "print(lr_binary.classes_)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1mvoYhUVQmFY", "outputId": "a8eee751-d8f9-47d7-a890-0cb480ba03c2" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[-0.4037798 -0.57620209 -0.66280298 -1.01290277 -0.73168947]] [-2.16155132]\n" ] } ], "source": [ "# Learned coefficients (one per input feature) and the intercept.\n", "print(lr_binary.coef_, lr_binary.intercept_)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SxrRy9m8A5Hy", "outputId": "a092eee9-5329-4de9-ee2a-65c2782e36ce" }, "outputs": [ { "output_type": 
"stream", "name": "stdout", "text": [ "[-6.02927744 3.57123907 -5.26568906 -4.24321775 -6.0607117 ]\n" ] } ], "source": [ "# Raw decision values (z) of the two-class model for the same five samples.\n", "decisions = lr_binary.decision_function(train_bream_smelt[:5])\n", "print(decisions)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SeuhSRuiA9yZ", "outputId": "f9127243-4293-4dca-b0c6-4f81af5d0fe1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[0.00240145 0.97264817 0.00513928 0.01415798 0.00232731]\n" ] } ], "source": [ "from scipy.special import expit\n", "\n", "# Passing the decision values through the sigmoid reproduces the second\n", "# column of the predict_proba output printed earlier.\n", "print(expit(decisions))" ] }, { "cell_type": "markdown", "metadata": { "id": "6ee-s4l7EuVo" }, "source": [ "### 로지스틱 회귀로 다중 분류 수행하기" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7QugsbD2X8bf", "outputId": "1b5f1118-3b3a-43a6-b7d5-a5826475c001" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.9327731092436975\n", "0.925\n" ] } ], "source": [ "# Multiclass model trained on all species in the scaled training split.\n", "lr = LogisticRegression(C=20, max_iter=1000)\n", "lr.fit(train_scaled, train_target)\n", "\n", "# Score on the training split first, then on the test split.\n", "print(lr.score(train_scaled, train_target))\n", "print(lr.score(test_scaled, test_target))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0taO0XnF9dha", "outputId": "e13af0b2-9a8f-4703-92f5-420bae40bc42" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Perch' 'Smelt' 'Pike' 'Roach' 'Perch']\n" ] } ], "source": [ "# Predicted species for the first five test samples.\n", "print(lr.predict(test_scaled[:5]))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pqZosYezZOi3", "outputId": "366bf3bc-5df7-4056-de6a-838320bc7cb7" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[0. 0.014 0.841 0. 0.136 0.007 0.003]\n", " [0. 0.003 0.044 0. 0.007 0.946 0. ]\n", " [0. 0. 0.034 0.935 0.015 0.016 0. 
]\n", " [0.011 0.034 0.306 0.007 0.567 0. 0.076]\n", " [0. 0. 0.904 0.002 0.089 0.002 0.001]]\n" ] } ], "source": [ "# Per-class probabilities for the first five test samples, rounded to 3 decimals.\n", "proba = lr.predict_proba(test_scaled[:5])\n", "print(proba.round(decimals=3))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CXASv4WU87UF", "outputId": "6f8b2061-88b0-4cd9-9f11-cd907919ef24" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Bream' 'Parkki' 'Perch' 'Pike' 'Roach' 'Smelt' 'Whitefish']\n" ] } ], "source": [ "# Class labels stored on the fitted multiclass model.\n", "print(lr.classes_)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1swPv6ZOZTjg", "outputId": "c52ceba6-b8ab-4f90-ecd6-38ebb6acbbef" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "(7, 5) (7,)\n" ] } ], "source": [ "# Shapes of the coefficient matrix and of the intercept vector.\n", "print(lr.coef_.shape, lr.intercept_.shape)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "s9iRz1iAd7Oe", "outputId": "07f4b83c-1211-45eb-f698-28a381388c57" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[ -6.5 1.03 5.16 -2.73 3.34 0.33 -0.63]\n", " [-10.86 1.93 4.77 -2.4 2.98 7.84 -4.26]\n", " [ -4.34 -6.23 3.17 6.49 2.36 2.42 -3.87]\n", " [ -0.68 0.45 2.65 -1.19 3.26 -5.75 1.26]\n", " [ -6.4 -1.99 5.82 -0.11 3.5 -0.11 -0.71]]\n" ] } ], "source": [ "# Raw decision values for the first five test samples, rounded to 2 decimals.\n", "decision = lr.decision_function(test_scaled[:5])\n", "print(decision.round(decimals=2))" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "49CcsDHZeJma", "outputId": "53f704df-743d-45ad-e76e-6729cecfbb03" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[[0. 0.014 0.841 0. 0.136 0.007 0.003]\n", " [0. 0.003 0.044 0. 0.007 0.946 0. ]\n", " [0. 0. 0.034 0.935 0.015 0.016 0. ]\n", " [0.011 0.034 0.306 0.007 0.567 0. 0.076]\n", " [0. 0. 
0.904 0.002 0.089 0.002 0.001]]\n" ] } ], "source": [ "from scipy.special import softmax\n", "\n", "# Row-wise softmax (axis=1) of the decision values; the printed result\n", "# matches the rounded predict_proba output shown earlier.\n", "proba = softmax(decision, axis=1)\n", "print(proba.round(decimals=3))" ] } ], "metadata": { "colab": { "name": "4-1 로지스틱 회귀.ipynb", "provenance": [] }, "kernelspec": { "display_name": "default:Python", "language": "python", "name": "conda-env-default-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" } }, "nbformat": 4, "nbformat_minor": 0 }