{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# **Click-Through Rate Prediction with Logistic Regression**\n", "1. Data preprocessing and one-hot encoding\n", "1. How **logistic regression** works\n", "1. **Gradient descent** and **stochastic gradient descent**\n", "1. Training a **logistic regression** classifier and predicting with it\n", "1. Logistic regression with **L1 and L2 regularization**\n", "1. Online learning\n", "1. **Feature selection** with a **random forest**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **1 One-Hot Encoding**\n", "1. Convert a **categorical feature** into **binary numerical features**\n", "1. A **categorical** feature with **K possible values** is mapped to **K binary features**, indexed 1 to K\n", "1. The encoded data can be converted back to its **original form**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **01 Creating One-Hot Encoded Data**\n", "`DictVectorizer` turns a list of feature dicts into a one-hot encoded array, and its `vocabulary_` attribute holds the mapping from each `feature=value` pair to a column index" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 1., 1., 0., 0.],\n", "       [1., 0., 0., 0., 0., 1.],\n", "       [1., 0., 0., 1., 0., 0.],\n", "       [0., 1., 0., 0., 0., 1.],\n", "       [0., 0., 1., 0., 0., 1.],\n", "       [0., 0., 1., 0., 1., 0.],\n", "       [0., 1., 0., 1., 0., 0.]])" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One-hot encode categorical data given as a list of dicts\n", "from sklearn.feature_extraction import DictVectorizer\n", "dict_one_hot_encoder = DictVectorizer(sparse=False)\n", "\n", "X_dict = [{'interest': 'tech', 'occupation': 'professional'},\n", "          {'interest': 'fashion', 'occupation': 'student'},\n", "          {'interest': 'fashion', 'occupation': 'professional'},\n", "          {'interest': 'sports', 'occupation': 'student'},\n", "          {'interest': 'tech', 'occupation': 'student'},\n", "          {'interest': 'tech', 'occupation': 'retired'},\n", "          {'interest': 'sports', 'occupation': 'professional'}]\n", "\n", "X_encoded = dict_one_hot_encoder.fit_transform(X_dict)\n", "X_encoded" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'interest=fashion': 0,\n", " 'interest=sports': 1,\n", " 'interest=tech': 2,\n", " 'occupation=professional': 3,\n", " 'occupation=retired': 4,\n", " 'occupation=student': 5}\n" ] } ], "source": [ "# Inspect the mapping from each categorical value to its column index\n", "from pprint import pprint\n", "pprint(dict_one_hot_encoder.vocabulary_)" ] }, { "cell_type": "markdown", 
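"metadata": {}, "source": [ "To make the mapping in the section above concrete, the next cell is a minimal hand-rolled sketch of one-hot encoding in plain NumPy. It is not how `DictVectorizer` is implemented; the helper `one_hot_by_hand` and the names `demo` and `column_of` are hypothetical and only mirror the `feature=value` column naming and sorted column order shown by `vocabulary_`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helper (not part of scikit-learn): one-hot encode a list of dicts\n", "def one_hot_by_hand(samples):\n", "    # One column per distinct 'feature=value' pair, in sorted order\n", "    names = sorted({f'{k}={v}' for sample in samples for k, v in sample.items()})\n", "    column_of = {name: i for i, name in enumerate(names)}\n", "    encoded = np.zeros((len(samples), len(names)))\n", "    for row, sample in enumerate(samples):\n", "        for k, v in sample.items():\n", "            # Switch on the single column that matches this feature's value\n", "            encoded[row, column_of[f'{k}={v}']] = 1.0\n", "    return encoded, column_of\n", "\n", "demo = [{'interest': 'tech', 'occupation': 'professional'},\n", "        {'interest': 'fashion', 'occupation': 'student'}]\n", "encoded, column_of = one_hot_by_hand(demo)\n", "print(column_of)\n", "print(encoded)" ] }, { "cell_type": "markdown", 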
"metadata": {}, "source": [ "## **02 Converting Data by Using Map Data**\n", "위에서 학습한 **dict_one_hot_encoder** 를 활용하여 데이터를 컨버팅/ 복원" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 1. 0. 0. 1. 0.]]\n" ] } ], "source": [ "# 위에서 매팽한 table 을 사용하여 새로운 데이터 인코딩\n", "new_dict = [{'interest': 'sports', 'occupation': 'retired'}]\n", "new_encoded = dict_one_hot_encoder.transform(new_dict)\n", "print(new_encoded)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'interest=sports': 1.0, 'occupation=retired': 1.0}]\n" ] } ], "source": [ "# new_encoded 인코딩 데이터를 원본형태로 되돌린다\n", "print(dict_one_hot_encoder.inverse_transform(new_encoded))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **03 Learning New Map Data**\n", "1. **new_encoded :** 새로운 매핑 데이터 추가하면, 결과적으로 **무시된다**\n", "1. 두개의 **dict** 데이터 중 **없는건 제외하고 나머지만 Converting** 된다" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 1. 0.]\n", " [0. 0. 1. 0. 0. 0.]]\n" ] } ], "source": [ "# 1개의 인덱스에 포함된 2개의 Dict 중, 1개만 converting 된다\n", "new_dict = [{'interest': 'unknown_interest', 'occupation': 'retired'},\n", " {'interest': 'tech', 'occupation': 'unseen_occupation'}]\n", "new_encoded = dict_one_hot_encoder.transform(new_dict)\n", "print(new_encoded)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **04 LabelEncoder 를 활용한 One-Hot-Encoding**\n", "1. **X_int** : One Hot 의 **인덱스값을** 출력한다\n", "1. 
More compact and easier to tell apart" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5 1]\n", " [0 4]\n", " [0 1]\n", " [3 4]\n", " [5 4]\n", " [5 2]\n", " [3 1]]\n" ] } ], "source": [ "import numpy as np\n", "X_str = np.array([['tech', 'professional'],\n", "                  ['fashion', 'student'],\n", "                  ['fashion', 'professional'],\n", "                  ['sports', 'student'],\n", "                  ['tech', 'student'],\n", "                  ['tech', 'retired'],\n", "                  ['sports', 'professional']])\n", "\n", "from sklearn.preprocessing import LabelEncoder, OneHotEncoder\n", "label_encoder = LabelEncoder()\n", "X_int = label_encoder.fit_transform(X_str.ravel()).reshape(*X_str.shape)\n", "print(X_int)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 1. 1. 0. 0.]\n", " [1. 0. 0. 0. 0. 1.]\n", " [1. 0. 0. 1. 0. 0.]\n", " [0. 1. 0. 0. 0. 1.]\n", " [0. 0. 1. 0. 0. 1.]\n", " [0. 0. 1. 0. 1. 0.]\n", " [0. 1. 0. 1. 0. 0.]]\n" ] } ], "source": [ "# One-hot encode the integer matrix X_int into X_encoded\n", "one_hot_encoder = OneHotEncoder()\n", "X_encoded = one_hot_encoder.fit_transform(X_int).toarray()\n", "print(X_encoded)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 1. 0.]\n", " [0. 0. 1. 0. 0. 0.]\n", " [0. 0. 0. 0. 0. 
0.]]\n" ] } ], "source": [ "# Values missing from the fitted mapping are ignored, as above\n", "new_str = np.array([['unknown_interest', 'retired'],\n", "                    ['tech', 'unseen_occupation'],\n", "                    ['unknown_interest', 'unseen_occupation']])\n", "\n", "def string_to_dict(columns, data_str):\n", "    # Turn each row of strings into a {column: value} dict for DictVectorizer\n", "    data_dict = []\n", "    for sample_str in data_str:\n", "        data_dict.append({column: value for column, value in zip(columns, sample_str)})\n", "    return data_dict\n", "\n", "columns = ['interest', 'occupation']\n", "new_encoded = dict_one_hot_encoder.transform(string_to_dict(columns, new_str))\n", "print(new_encoded)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
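The unseen-value behaviour shown above can also be sketched with `OneHotEncoder` on its own. This is a hedged sketch rather than the notebook's method: it assumes a reasonably recent scikit-learn in which `OneHotEncoder` accepts string columns directly and `handle_unknown='ignore'` encodes categories unseen during fitting as all zeros. The names `train_rows`, `new_rows` and `str_one_hot_encoder` are introduced here for illustration only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import OneHotEncoder\n", "\n", "# Same training samples as X_str, written as plain lists of strings\n", "train_rows = [['tech', 'professional'],\n", "              ['fashion', 'student'],\n", "              ['fashion', 'professional'],\n", "              ['sports', 'student'],\n", "              ['tech', 'student'],\n", "              ['tech', 'retired'],\n", "              ['sports', 'professional']]\n", "\n", "# Assumption: handle_unknown='ignore' turns unseen categories into all-zero columns\n", "str_one_hot_encoder = OneHotEncoder(handle_unknown='ignore')\n", "str_one_hot_encoder.fit(train_rows)\n", "\n", "new_rows = [['unknown_interest', 'retired'],\n", "            ['tech', 'unseen_occupation']]\n", "print(str_one_hot_encoder.categories_)\n", "print(str_one_hot_encoder.transform(new_rows).toarray())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "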
\n", "\n", "# **2 Logistic Regression Classifier**\n", "1. The logistic (sigmoid) function $y(z) = \\frac{1}{1+\\exp(-z)}$ maps **any real-valued input to a value between 0 and 1**\n", "1. The algorithm **scales well** to large datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **01 How Logistic Regression Works**\n", "Like the naive Bayes classifier, logistic regression is a **probabilistic classifier**" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Define the logistic (sigmoid) function\n", "import numpy as np\n", "\n", "def sigmoid(input):\n", "    return 1.0 / (1 + np.exp(-input))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAADTCAYAAADeUOthAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xl8lOW5//HPlYWEfTEgsgYURMUNI4iK86hooa3oEa0ouOCCR0BFpQq1ooK/I1W00F/Bo8hSwQoVOTWogFCBoyBCgLATCAGyQExCQshCyHadP2acTilLkFmyXO/Xa17M/Sy5r0mG79zPMs8jqooxxphzFxbqAowxprawQDXGGD+xQDXGGD+xQDXGGD+xQDXGGD+xQDXGGD+xQDXGGD+xQDXGGD+xQDXGGD+JCHUB/hITE6OxsbGhLsMYU8ts2LAhR1VbVmXZWhOosbGxJCQkhLoMY0wtIyIHqrqsbfIbY4yfWKAaY4yfBCxQRWSmiGSJyLZTzBcR+ZOIJIvIFhHp4TPvYRHZ43k8HKgajTHGnwI5Qp0N9DvN/P5AF89jGPAegIi0AF4FegE9gVdFpHkA6zTGGL8IWKCq6v8CuadZ5E7gI3VbCzQTkQuAXwDLVDVXVfOAZZw+mAFISkpi9uzZAJSVleE4DnPnzgWguLgYx3GYP38+APn5+TiOw8KFCwHIycnBcRwWLVoEQGZmJo7jsGTJEgDS0tJwHIfly5cDkJKSguM4rFq1ytu34zisWbMGgG3btuE4DuvXrwcgMTERx3FITEwEYP369TiOw7Zt7sH7mjVrcByHpKQkAFatWoXjOKSkpACwfPlyHMchLS0NgCVLluA4DpmZmQAsWrQIx3HIyckBYOHChTiOQ35+PgDz58/HcRyKi4sBmDt3Lo7jUFZWBsDs2bNxHMf7u5w+fTp9+/b1tqdNm0b//v297SlTpjBgwABve9KkSQwcONDbnjhxIoMGDfK2J0yYwJAhQ7ztcePGMXToUG977NixDBs2zNsePXo0I0aM8LZHjRrFqFGjvO0RI0YwevRob3vYsGGMHTvW2x46dCjjxo3ztocMGcKECRO87UGDBjFx4kRve+DAgUyaNMnbHjBgAFOmTPG2+/fvz7Rp07ztvn37Mn36dG/bcRx773nee58uWECfm1wkp/9IWm4xk9+fTVzvG/nfHemsSc7h5bencWXP6/kyMY0vtxzimfF/5LJrejNvXSpz1x5g6EtvcmncDby3ci9TVyTzm2de45KeNzF5+W7eXbabu556mUuvu4W3luxi4uJd/PKx39L9htt544sdjF+0g74PPUf3Pv0Z9/k2fv/3rdx0/0i6u37NmM+28OKCzfS+50luv+s+79+uKu+9sxHKo/xtgTSfdrpn2qmm/xsRGY
Z7dEtUVFRgqjSmjqlUJSOvmNz9uazbd5gfj5Ywe/U+wpsUkrhmH3t+LGDExxspi2pCyoYtHDiQR58/fENZZCOO7tzE0f253PLOKsKiGlK0cycFGfkMmfEDYZHRFG5PoTCzgOEfb0TCIyjcmkphdiFjFm4FoGDzQYpzivjDkl3u9q4sinOKmbx8j7u9J5uSw0VM/zYFESHvQB4lucV8si6VMBGyD+ZTkneMRZsPEiZCZlYBJfklrEjKQhDSc4vpEFURsN+dBPKK/SISC3yhqt1PMu8LYKKqfudp/wN4CXCAaFV9wzP9FeCYqk468Wf4iouLUzttypgzKymrYG92IWm5xRw4XExqrvuRkXeMrILjFB4vP+l60ZFhNImOpGn9SJrUj6RJdASNoyNpGBVB/chwoiPDqB8ZTv164URFhv/LtKiIcCLChchwISIsjPAwITI8zD0tzP1vRJgQ4TMtLAzCRRARwgREJMi/KTcR2aCqcVVZNpQj1AygvU+7nWdaBu5Q9Z2+MmhVGVOLHCkuZWNqHjsOHmVnZgE7Dx1lf04RlT7jqGYNIunQogHdLmjMTV1b0qpJFC0bRdGycRStGkcT07geTetHEhURHroXUkOEMlDjgZEiMg/3Aah8VT0kIkuB//I5EHU7MPZUP8QY80+HC4/zv3uyWbcvj4T9uezJKvTOa9+iPpe0bsKvr2hD1/MbEXteQ9q3aEDT+pEhrLh2CVigisgnuEeaMSKSjvvIfSSAqv438BXwSyAZKAaGeublisgEYL3nR41X1dMd3DKmTkvKLGDp9ky+2ZXF5vQjqELjqAh6dGzOnVe14ZqOLejetgmNoy04Ay2g+1CDyfahmrokM7+EzxMz+J9NGezKLEAErmzXjJsvbsXN3VpyWZumhIeFZp9jbVNT9qEaY86CqrJuXy6z1+zn6x0/UlGpXN2hGa8PuIxfXn4BLRvbmS6hZoFqTDWnqizfmcXk5bvZfvAozRpE8nifTtx/bQdiYxqGujzjwwLVmGpsZVIW73y9m60Z+cSe14A3776cu65qS/16dsS9OrJANaYa2p9TxPgvdvDNrizat6jP2/dcwX9c3ZaIcLueUXVmgWpMNVJWUcnUFclMW7GXehFh/P5Xl/BQ71jqRViQ1gQWqMZUE8lZBTw3fzNbM/IZcGUbfv+rS2jVJDrUZZmzYIFqTIipKn9dl8r4RTtoUC+c/x7Sg37dLwh1WeZnsEA1JoSOl1cw7u/bmZ+Qxk1dWzLp3ito1dhGpTWVBaoxIZJdcJwnPkogMe0IT99yEaP6drWT8Ws4C1RjQuDA4SIemrmOrKPHbRO/FrFANSbItmXk88isdZRXKh8/0YseHeyGFLWFBaoxQbT9YD6DP/yBRlERzHu0Jxe1ahTqkowfWaAaEyRJmQUM+fAHGtYLZ96w62jfokGoSzJ+ZmcLGxME+3KKGPzhWupFhPGJhWmtZSNUYwIst6iUobPWUakw/4nr6HieXdCktrIRqjEBVFJWweN/Wc+h/BKmPxTHhS1tn2ltZiNUYwJEVfntgi1sSjvCtAd6cE1HO5pf29kI1ZgAmfHdPhZtPsjo2y+m/+V2nmldYIFqTACs25fLm4t3cful5zPcuTDU5ZggsUA1xs+yCkoY8deNdGjRgEm/uTJk95M3wWf7UI3xI1XlxQVbOHqsjDmP9aSJ3Wm0TrERqjF+9PEPqaxMymZs/250a90k1OWYILNANcZPUrIL+X9f7qRPlxge6h0b6nJMCFigGuMH5RWVPPe3zURFhjHp3isJs8vw1Um2D9UYP5i1ej+b047w5weu5ny7bUmdZSNUY85Rel4x7y7bTd9LWvErO9+0TrNANeYcqCrjPt+OCLx+Z3c7RaqOs0A15hws3pbJN7uyeP62rrRtVj/U5ZgQC2igikg/EUkSkWQRGXOS+X8UkUTPY7eIHPGZV+EzLz6QdRrzcxQdL+f1Rdvp3rYJj1wfG+pyTDUQsINSIhIOTAVuA9KB9S
ISr6o7flpGVZ/zWf5p4GqfH3FMVa8KVH3GnKv3Vu7lx6PHeW/INUSE28aeCewItSeQrKopqloKzAPuPM3y9wOfBLAeY/wmPa+YD75N4a6r2tg9oYxXIAO1LZDm0073TPs3ItIR6AR84zM5WkQSRGStiNx1ivWGeZZJyM7O9lfdxpzRm4t3ESbwYr9uoS7FVCPVZTtlELBAVSt8pnVU1TjgAWCyiPzbJXtU9QNVjVPVuJYtWwarVlPHrd+fy5dbDvHkTRfSxg5EGR+BDNQMoL1Pu51n2skM4oTNfVXN8PybAqzkX/evGhMSqsobX+6kdZNonnR1DnU5ppoJZKCuB7qISCcRqYc7NP/taL2IdAOaA9/7TGsuIlGe5zHADcCOE9c1JtiW7fiRzWlHeO62LjSoZ180NP8qYO8IVS0XkZHAUiAcmKmq20VkPJCgqj+F6yBgnqqqz+qXAO+LSCXu0J/oe3aAMaFQUalM+jqJzjENGdijXajLMdVQQD9iVfUr4KsTpo07of3aSdZbA1weyNqMOVvxmzPY/WMhf37gajtNypyUvSuMqYLS8kr+uGwPl17QhF92t+/rm5OzQDWmCuYnpJGaW8xv+11sl+Yzp2SBaswZHC+vYNqKZOI6NsfpaqfnmVOzQDXmDBZuzOBQfgnP3NrFriZlTssC1ZjTKK+oZNrKZK5s15Q+XWJCXY6p5ixQjTmN+M0HScs9xshbbHRqzswC1ZhTqKhUpq5IplvrxtzarVWoyzE1gAWqMaewZFsme7OLGHnLRXZk31SJBaoxJ6Gq/P9v9tC5ZUP623mnpoosUI05iRVJWezKLGC4cxHhNjo1VWSBasxJzPxuP62bRHPnVW1CXYqpQSxQjTlBUmYB3yXn8GDvjkTad/bNWbB3izEnmLV6H1ERYTzQs0OoSzE1jAWqMT4OFx5n4aYM7u7RjuYN64W6HFPDWKAa4+OTdamUllfy6A2xoS7F1EAWqMZ4lJZXMmftAfp0iaHL+Y1DXY6pgSxQjfFYvO0QPx49zqM3dgp1KaaGskA1BveJ/DO+20fnlg1xdbFL9JmfxwLVGGBjah5b0vMZekMn+5qp+dksUI3BfSJ/k+gIBvZoG+pSTA1mgWrqvPS8YhZvO8T9vTrYraHNObFANXXenO8PICI81Ds21KWYGs4C1dRpxaXlfLIulX6XtaZts/qhLsfUcBaopk77bGMGR0vKefTG2FCXYmoBC1RTZ1VWKrNW7+PKdk3p0aF5qMsxtYAFqqmzVu3JJiW7iEdv7GT3izJ+cdpDmiLSGxgC9AEuAI4B24Avgbmqmh/wCo0JkJnf7eP8JlF2RX7jN6ccoYrIYuBxYCnQD3egXgr8HogGPheRAcEo0hh/2/1jAd/uyeGh3rHUi7ANNeMfp3snPaiqj6lqvKoeVNVyVS1U1Y2q+o6qOsCa0/1wEeknIkkikiwiY04y/xERyRaRRM/jcZ95D4vIHs/j4Z/9Co05iVmr9xMVEcb9ds1T40enDFRVzQEQkVdEpL3vPBEZ5rvMyYhIODAV6I97ZHu/iFx6kkXnq+pVnseHnnVbAK8CvYCewKsiYkcNjF/kFZWycGM6d/doSwu75qnxo6ps6zwNLBGRm32m/WcV1usJJKtqiqqWAvOAO6tY1y+AZaqaq6p5wDLcux1OKSkpidmzZwNQVlaG4zjMnTsXgOLiYhzHYf78+QDk5+fjOA4LFy4EICcnB8dxWLRoEQCZmZk4jsOSJUsASEtLw3Ecli9fDkBKSgqO47Bq1Spv347jsGaNe8C+bds2HMdh/fr1ACQmJuI4DomJiQCsX78ex3HYtm0bAGvWrMFxHJKSkgBYtWoVjuOQkpICwPLly3Ech7S0NACWLFmC4zhkZmYCsGjRIhzHISfH/fm2cOFCHMchP9+9i3v+/Pk4jkNxcTEAc+fOxXEcysrKAJg9ezaO43h/l9OnT6dv377e9rRp0+jfv7+3PWXKFAYM+OfenkmTJjFw4EBve+
LEiQwaNMjbnjBhAkOGDPG2x40bx9ChQ73tsWPHMmzYMG979OjRjBgxwtseNWoUo0aN8rZHjBjB6NGjve1hw4YxduxYb3vo0KGMGzfO2x4yZAgTJkzwtvvecTdZ381n6A3uq0oNHDiQSZMmeecPGDCAKVOmeNv9+/dn2rRp/1y/b1+mT5/ubTuOY++9WvzeOxtVCdQM3KPMiSLyW8+0qhwSbQuk+bTTPdNONFBEtojIAp+RcJXWFZFhIpIgIgk//YGMOZ2yikr25xQRe15Duto1T42fiaqefgGRTap6tYhEA+8BjYDLVbXbGda7B+inqo972g8CvVR1pM8y5wGFqnpcRJ4E7lPVW0RkNBCtqm94lnsFOKaqk/69J7e4uDhNSEioyms2dVj85oM888kmZj4Sxy3dzg91OaYGEJENqhpXlWWrMkJNAFDVElUdCqwEqrLjKQPw3ffazjPNS1UPq+pxT/ND4JqqrmvMzzHzu310immI07VVqEsxtdAZA1VVnzihPVVVO1fhZ68HuohIJxGpBwwC4n0XEBHfEwAHADs9z5cCt4tIc8/BqNs904z52Tam5pGYdoShN8TaNU9NQJzyxH4RWQR8ACxR1bIT5nUGHgH2q+rMk62vquUiMhJ3EIYDM1V1u4iMBxJUNR54xnMuazmQ6/mZqGquiEzAHcoA41U19+e/TGPco9PG0REM7NEu1KWYWuqU+1BFpDXwPHA3kAdk4z6hvxOQDPxZVT8PUp1nZPtQzekcPHKMPm+t4LEbO/G7X14S6nJMDXI2+1BPOUJV1UzgRRFJB77FHabHgN2qWuyXSo0Jko++P4Cq8lDvjqEuxdRiVTko1Qr4FHgOaI07VI2pMbzXPO3emnbNG4S6HFOLVeWg1O+BLsAM3Ps494jIf4nIhQGuzRi/WLgxg/xjZTx6g90e2gRWla4Koe4drZmeRznQHFggIm8FsDZjztlP1zy9vG1Trulo3142gXXGQBWRZ0VkA/AWsBr3Sf1P4T5ndOBpVzYmxFbtzmZvdhGP97FrnprAq8otHlsAd6vqAd+JqlopIr8OTFnG+MeH36XQukk0v7zcrnlqAq8q+1BfPTFMfebtPNl0Y6qDnYeOsjr5MA9fH0tkuF3z1ASevctMrTXju33UjwznAbvmqQkSC1RTK2UVlBCfeJB749rRtEFkqMsxdYQFqqmV5n5/gLLKSu81T40JBgtUU+uUlFUwZ+0B+l5yPp1iGoa6HFOHWKCaWmfhxgzyist47EYbnZrgskA1tUplpTJz9T66t21Cr04tQl2OqWMsUE2t8o9dWSRnFfL4jZ3tRH4TdBaoptZQVaatTKZd8/r8+go7kd8EnwWqqTXWpuSyKfUIT97UmQg7kd+EgL3rTK0xbWUyMY3qcW9c+zMvbEwAWKCaWmFbRj7f7snh0Rs7ER0ZHupyTB1lgWpqhfdW7qVxVARDrrMr8pvQsUA1NV5KdiFfbTvEg7070iTavmZqQscC1dR4763cS73wMPuaqQk5C1RTo+3PKWLhpgwe6NWBlo2jQl2OqeMsUE2N9qdv9hAZLjzl2C3OTOhZoJoaa292IX/flMGD13WkVePoUJdjjAWqqbn+9I89REWE86TLRqemerBANTXSnh8LiN98kIevjyWmke07NdWDBaqpkd5dtpsGkeEMu6lzqEsxxssC1dQ4Gw7ksXhbJo/36UyLhvVCXY4xXgENVBHpJyJJIpIsImNOMv95EdkhIltE5B8i0tFnXoWIJHoe8YGs09QcqsqbX+2kZeMoG52aaiciUD9YRMKBqcBtQDqwXkTiVXWHz2KbgDhVLRaRp4C3gPs8846p6lWBqs/UTEu3/0jCgTz+6z8up2FUwN6+xvwsgRyh9gSSVTVFVUuBecCdvguo6gpVLfY01wLtAliPqeHKKir5w5JdXNSqEb+Js7eKqX4CGahtgTSfdrpn2qk8Biz2aUeLSIKIrBWRu062gogM8yyTkJ2dfe4Vm2rt47UH2J
dTxNj+3ex6p6ZaqhbbTCIyBIgDXD6TO6pqhoh0Br4Rka2qutd3PVX9APgAIC4uToNWsAm67ILjvLNsN326xHBLt1ahLseYkwrkx3wG4Hul33aeaf9CRPoCLwMDVPX4T9NVNcPzbwqwErg6gLWaam7i4l2UlFXw2oDL7F5RptoKZKCuB7qISCcRqQcMAv7laL2IXA28jztMs3ymNxeRKM/zGOAGwPdglqlDEvbn8tnGdJ7o05kLWzYKdTnGnFLANvlVtVxERgJLgXBgpqpuF5HxQIKqxgNvA42ATz2jjlRVHQBcArwvIpW4Q3/iCWcHmDqivKKSVz7fTpum0Yy85aJQl2PMaQV0H6qqfgV8dcK0cT7P+55ivTXA5YGszdQMM77bx85DR3lvcA8a1KsWu/yNOSU7VGqqreSsQt5ZtptfXHY+/bq3DnU5xpyRBaqplioqlRcXbKZBvXAm3NXdDkSZGsEC1VRLs1bvY2PqEV6941K71qmpMSxQTbWz4+BR3lqaRN9LWnHXVaf7Logx1YsFqqlWikvLGfnJRprVj+QPA6+wTX1To9hhU1OtvB6/g305RXz8WC/OswtHmxrGRqim2li4MZ35CWkMdy7k+otiQl2OMWfNAtVUC5vTjjBm4Vau69yCUX27hrocY34WC1QTclkFJTw5ZwMtG0UxbfA1RNqVpEwNZftQTUgdK63gyTkbyD9WxmdPXW+3NDE1mgWqCZmyikqGf7yBxLQjvDe4B5e2aRLqkow5J7ZtZUKislJ56bMtrEjK5o27utOv+wWhLsmYc2aBaoKuslJ5NX47Czdm8PxtXRncq+OZVzKmBrBNfhNUlZXK7/5nK/PWp/GkqzNP2yX5TC1igWqCpqyikpcWbGHhpgyevuUinr+tq30TytQqFqgmKPKPlTH84w2sTj7MC7d15elbu4S6JGP8zgLVBFxabjGPzl7Pvpwi3r7nCu6Na3/mlYypgSxQTUB9vT2T0Z9uBuCjR3vaV0pNrWaBagKipKyCSUuT+PC7fVzetilTH+hBh/MahLosYwLKAtX43YYDufx2wRZSsot4qHdHXv7VJURFhIe6LGMCzgLV+M2R4lImL9/DX77fT5um9fno0Z7c1LVlqMsyJmgsUM05K6uoZO7aA0xevoeCkjIevK4jL/brRqMoe3uZusXe8eZnKymr4NMN6by/ai/pecfo0yWGl391Cd1a23fyTd1kgWrOWlZBCZ8mpDN7zX6yC47To0Mz3rirO66uLe1EfVOnWaCaKimrqGR1cg7z1qWxfOePlFcqN1x0HlMGXUXvzudZkBqDBao5jZKyCr7fe5ivth7i6x0/kn+sjBYN6/HojZ0YdG17OrdsFOoSjalWLFCNV3lFJTsPFbB6bw6rk3NYty+X4+WVNI6K4LZLz6f/5RdwU9cYOwXKmFOwQK2jjpdXsD+nmJ2HjrI5/Qhb0vPZfjCfkrJKAC4+vzGDe3WkT9cYrr/wPAtRY6ogoIEqIv2AKUA48KGqTjxhfhTwEXANcBi4T1X3e+aNBR4DKoBnVHVpIGutbVSV/GNlHDxSwqH8YxzMLyE9t5i92YUkZxWSmltMpbqXjY4Mo3ubpgzu1ZEr2zfjus4taNU4OrQvwJgaKGCBKiLhwFTgNiAdWC8i8aq6w2exx4A8Vb1IRAYBfwDuE5FLgUHAZUAbYLmIdFXVikDVW12VVVRSXFpBcWk5Rcf/+W/h8XLyikvJKyolr7iMvKJScotLOVJcyuHCUg7ll3Cs7F9/XfXCw+gU05DL2jRlwJVtuLBVI7qe35gurRoRYTfGM+bcqWpAHkBvYKlPeyww9oRllgK9Pc8jgBxATlzWd7lTPRo1aqSzZs1SVdXS0lJ1uVw6Z84cVVUtKipSl8ul8+bNU1XVI0eOaI9eN+jotz7Qv61P1elfJ+olPa7TF96eoXPX7tc/fbFOu13dS5+fNEtnfJuif/j0W+16VS99dtJfdNqKZH1t7j/0oit76tPvzNU/Lk
vSl2Ys1s5XXKvD3/2rTly8U5+Z9rl2uvxaffKdefpa/DZ97N2/aftL43TwxHk64uMNetdrs7X1xT30169/rL/57zV60wvvaYuLrtIbx87Rmyet0MueeFcbxl6hscNnaMeXvtBW972hUe27a9unZrnb976uUe27a7sRc7TjS1/oBfeO08axV6gz/nO97/01+otn39LY7tfqlK8S9cstB/XNP3+o19/YRwsKClVVdc6cOepyubS0tFRVVWfNmqUul0t/8sEHH+itt97qbU+dOlX79evnbU+ePFnvuOMOb/vtt9/Wu+++29t+88039b777vO2x48fr4MHD/a2X3nlFX3kkUe87TFjxugTTzzhbb/wwgs6fPhwb/vZZ5/VZ5991tsePny4vvDCC972E088oWPGjPG2H3nkEX3llVe87cGDB+v48eO97fvuu0/ffPNNb/vuu+/Wt99+29u+4447dPLkyd52v379dOrUqd72rbfeqh988IG37XK5zuq953K59LPPPlNV1ezsbHW5XBofH6+qqocOHVKXy6WLFy9WVdXU1FR1uVy6bNkyVVXdu3evulwuXblypaqq7tq1S10ul65evVpVVbdu3aoul0vXrVunqqqbNm1Sl8ulmzZtUlXVdevWqcvl0q1bt6qq6urVq9XlcumuXbtUVXXlypXqcrl07969qqq6bNkydblcmpqaqqqqixcvVpfLpYcOHVJV1fj4eHW5XJqdna2qqp999pm6XC49cuSIqqrOmzdPXS6XFhUVqWrNfO8BCVrF3AvkJn9bIM2nnQ70OtUyqlouIvnAeZ7pa09Yt+2JHYjIMGAYQFRU1FkVl1VQwl/W7OfTw1uoKM4nO7uQj9buZ0HONioK88jOKWLuD6nUz95B+dFscg4X8cn6NOpn76LsSCaHc4v5W0Ia0VnNKDucTm7eMT7flEGDnGaUZmeQlV/C0h2ZNMpvRnn2YY4cK2X7oXyaRxzlWH4JpRWVHCutoAFQPzKceuFhtGkWTcsLmhBT0JiChvW455p2tO/Yif1bclmypxHPD7iM2A4d2LK2kL+mNeX9EddzcecOfLO0gnezVvHpU9cTExPDwoWH+FNiAx6+PpamTZtSsKMRkeFhhIXZqU3GBJKoamB+sMg9QD9VfdzTfhDopaojfZbZ5lkm3dPeizt0XwPWqupcz/QZwGJVXXCq/uLi4jQhIaHK9WUXHOdYaQUiEBYmCBAmQpiAeP51twUJ++e8MBH3Op55Py1vjKmdRGSDqsZVZdlAjlAzAN8rCbfzTDvZMukiEgE0xX1wqirrnpOWjc9uRGuMMWcSyCMR64EuItJJROrhPsgUf8Iy8cDDnuf3AN+oe8gcDwwSkSgR6QR0AdYFsFZjjDlnARuhevaJjsR9QCkcmKmq20VkPO6dvPHADGCOiCQDubhDF89yfwN2AOXACK2DR/iNMTVLwPahBtvZ7kM1xpiqOJt9qHbyoTHG+IkFqjHG+Emt2eQXkWzgwFmuFoP7ywShUpf7r8uvPdT91+XX/nP676iqVbqXT60J1J9DRBKqum/E+q89fdf1/uvyaw90/7bJb4wxfmKBaowxflLXA/UD679O9l3X+6/Lrz2g/dfpfajGGONPdX2EaowxfmOBaowxflLnA1VErhKRtSKSKCIJItIzBDU8LSK7RGS7iLwVgv5fEBEVkZgg9/u253VvEZHk1GCWAAAEK0lEQVT/EZFmQeizn4gkiUiyiIwJdH8n9N1eRFaIyA7P3/rZYPbvU0e4iGwSkS9C0HczEVng+bvvFJHeQez7Oc/vfZuIfCIifr/PT50PVOAt4HVVvQoY52kHjYjcDNwJXKmqlwGTgtx/e+B2IDWY/XosA7qr6hXAbtx3aggYn9vy9AcuBe733G4nWMqBF1T1UuA6YESQ+//Js8DOEPQL7nvMLVHVbsCVwapDRNoCzwBxqtod9wWbBvm7HwtUUKCJ53lT4GCQ+38KmKiqxwFUNSvI/f8ReBH37y
GoVPVrVS33NNfivu5tIPUEklU1RVVLgXm4P8yCQlUPqepGz/MC3GHyb3eiCCQRaQf8CvgwmP16+m4K3IT7KnOoaqmqHgliCRFAfc+1lxsQgP/rFqgwCnhbRNJwjw4DOko6ia5AHxH5QURWici1wepYRO4EMlR1c7D6PI1HgcUB7uNkt+UJaqD9RERigauBH4Lc9WTcH6CVQe4XoBOQDczy7HL4UEQaBqNjVc3A/f87FTgE5Kvq1/7uJ6C3ka4uRGQ50Poks14GbgWeU9XPROQ3uD89+wax/wigBe5NwGuBv4lIZ/XT+Wxn6Pt3uDf3A+Z0/avq555lXsa9OfxxIGupLkSkEfAZMEpVjwax318DWaq6QUScYPXrIwLoATytqj+IyBRgDPBKoDsWkea4t0Y6AUeAT0VkyE+3WfKbqt7Nr7Y+gHz+eT6uAEeD3P8S4Gaf9l6gZRD6vRzIAvZ7HuW4P71bB/n1PwJ8DzQIQl9nvBNvEGqIxH3R9eeD2a+n7zdxj8r3A5lAMTA3iP23Bvb7tPsAXwap73uBGT7th4Bp/u7HNvnd+1Fcnue3AHuC3P/fgZsBRKQrUI8gXIlHVbeqaitVjVXVWNz/0Xqoamag+/6JiPTDvfk5QFWLg9BlVW7LEzDivpvjDGCnqr4brH5/oqpjVbWd5+89CPcth4YEsf9MIE1ELvZMuhX3XTmCIRW4TkQaeP4OtxKAA2J1YpP/DJ4Apnh2VJfguS11EM0EZnruAFsKPKyej9A64M9AFLDMc+fYtar6n4HqTE9xW55A9XcSNwAPAltFJNEz7Xeq+lUQawi1p4GPPR9oKcDQYHSq7l0MC4CNuLfGNhGAr6DaV0+NMcZPbJPfGGP8xALVGGP8xALVGGP8xALVGGP8xALVGGP8xALVGGP8xALVGGP8xALV1Cki8p+ea98misg+EVkR6ppM7WEn9ps6SUQigW+At1R1UajrMbWDjVBNXTUF93fZLUyN39h3+U2dIyKPAB2BkSEuxdQytslv6hQRuQb4C9BHVfNCXY+pXWyT39Q1I3Ff0HuF58BU0G8FYmovG6EaY4yf2AjVGGP8xALVGGP8xALVGGP8xALVGGP8xALVGGP8xALVGGP8xALVGGP85P8AQsTIzyv0h5MAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "# -8~8 사이의 값으로 로지스틱 회귀모델을 구현\n", "import matplotlib.pyplot as plt\n", "plt.figure(figsize=(5,3))\n", "z = np.linspace(-8, 8, 1000)\n", "y = sigmoid(z)\n", "plt.plot(z, y)\n", "plt.axhline(y=0, ls='dotted', color='k')\n", "plt.axhline(y=0.5, ls='dotted', color='k')\n", "plt.axhline(y=1, ls='dotted', color='k')\n", "plt.yticks([0.0, 0.25, 0.5, 0.75, 1.0])\n", "plt.xlabel('z'); plt.ylabel('y(z)'); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **02 MSE를 최소로 하는 로지스틱 회귀**\n", "비용함수를 최소로(실질적으로는 **MSE기반의 비용함수를** 최소로)하는 값들을 예측한다" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/markbaum/Python/python/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: divide by zero encountered in log\n", " This is separate from the ipykernel package so we can avoid doing imports until\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAARAAAADUCAYAAABUBHwDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAGqxJREFUeJzt3Xl0XeV57/Hvc44ka7RmWbJkebaFRzxgBjOEIQ4QCjTlMrThMrUkKU2TG27akDSrTXvXvWm5hCSEpqEZmQkEAgsIBBxTqIsxtrGxjW3kSbZk2bIka5ZsWXr6x94WsmxLR8faZ+8tP5+1ztLROe8++/FZWj+/+917v6+oKsYYE4+I3wUYY8LLAsQYEzcLEGNM3CxAjDFxswAxxsTNAsQYEzfPAkREZorI+n6PFhH5qlf7M8YkniTiOhARiQI1wLmqWuX5Do0xCZGoQ5jLgR0WHsaMLokKkJuBpxK0L2NMgnh+CCMiKcA+YLaqHjjJ+3cDdwNkZGQsqqio8LQeY8yJ1q5dW6+qhcPdLhEBch1wj6ouG6rtxJlztWrbRk/rMcacSETWquri4W6XiEOYW4jx8OVor93YZ0yYeBogIpIBfBp4Ppb2dmewMeGS5OWHq2o7kB9zew9rMcaMvEBdiWodEGPCJVgBYn0QY0IlUAFi+WFMuAQqQCw/jAmXYAWIJYgxoRKsALE+iDGhEqgAsfwwJlwCFSCWH8aES7ACxBLEmFAJVoBYH8SYUAlWgFh+GBMqwQoQvwswxgxLoALEEsSYcAlUgNgYiDHh4vV8IDki8pyIbBWRLSJy/mDtbQzEmHDxdD4Q4AfAa6p6gzs3avpgjS0/jAkXzwJERLKBi4HbAVT1CHBk0I0sQYwJFS8PYSYDB4FfiMgHIvJTd4rDU7IxEGPCxcsASQIWAj9W1QVAO/CNgY1E5G4RWSMia3ptUmVjQsXLAKkGqlX1Pff353AC5Tiq+oiqLlbVxYh4WI4xZqR5FiCquh/YKyIz3ZcuBz4adBs7hDEmVLw+C/Nl4An3DMxO4I5BW1t+GBMqXi/rsB6IebUrdbZB7FDGmFAI1JWoAIeP9vpdgjEmRoELkK7uHr9LMMbEKHAB0mkBYkxoBC5AurrtEMaYsAhcgHQesR6IMWERvACxQxhjQiNwAXLYAsSY0AhcgFgPxJjwCFyA2CCqMeERuACxHogx4WEBYoyJW/AC5MhRv0swxsQocAHS2mUBYkxYBCpAoiIWIMaESKACJBIRWjq7/S7DGBMjT+cDEZHdQCvQAxxV1UHnBolGhJYuCxBjwsLrGckALlXV+lgaRkVosUMYY0IjUIcwUTuEMSZUvA4QBX4vImtF5O6TNei/rEP3kcM2iGpMiHgdIBeq6kLgKuAeEbl4YIP+yzqkp6XaGIgxIeJpgKhqjfuzDngBWDJY+2hEaDt8FFtgyphw8CxARCRDRLKOPQeWAZsG2yYaEVShza5GNSYUvDwLMw54wV2iIQl4UlVfG2yDaEToAZrauxmbmuxhacaYkeBZgKjqTmD+cLZJikToAerbD1Oen+5NYcaYEROo07hJEWdBqfrWwz5XYoyJRbACJOoGSNsRnysxxsQiWAESccqpb7MeiDFhEKgAEYHstGQO2iGMMaEQqAABKMhMsR6IMSERuAApzBpjPRBjQiJwAVKSnUZtc5ffZRhjYhC4ACnLTWN/SxdHe2x5B2OCLpAB0tOr1gsxJgQCGCDOFajVhzp9rsQYM5TABUhpThoANU0WIMYEXeACpCQnFRGoPtThdynGmCEELkDGJEUZl5XK3kbrgRgTdIELEIDy/HR2N7T7XYYxZgieB4iIREXkAxF5OdZtphdlsr2uDVWbmcyYIEtED+QrwJbhbDCtKJPmzm4O2iXtxgSapwEiImXAZ4GfDme76UVZAGw/0OZBVcaYkeJ1D+T7wN8Aw7qsdPq4TAC2H7QAMSbIvJxU+RqgTlXXDtGub12Ygwc
PAlCUNYas1CS27W/1qjxjzAjwsgeyFLjWXR/3aeAyEXl8YKP+68IUFhYCICLMGZ/NxppmD8szxpwuzwJEVe9T1TJVnQTcDPxBVT8f6/bzJmSzpbaFw0d7vCrRGHOaYgoQEXksltdG0rzSHLp71A5jjAmwWHsgs/v/IiJRYFGsO1HVt1T1muEUNq8sG4APq+0wxpigGjRAROQ+EWkF5olIi/toBeqAF70srCw3jbyMFDbsbfJyN8aY0zBogKjq/1PVLOB+VR3rPrJUNV9V7/OyMBFhYXkO7+9u9HI3xpjTEOshzMvu+raIyOdF5HsiMtHDugA4b0o+uxs6qG22G+uMCaJYA+THQIeIzAfuBXYAj3pWlev8qfkArNrZ4PWujDFxiDVAjqpzZ9t1wI9U9WEgy7uyHGcVjyU7LZl3d1iAGBNEsS6u3Soi9wG3AheJSARI9q4sRyQinD8ln3cq61FVRMTrXRpjhiHWHshNwGHgTlXdD5QB93tWVT+XnVVEbXMXm/e1JGJ3xphhiClA3NB4Ash273HpUlXPx0AALq8oQgTe3HIgEbszxgxDrFei3gisBv4HcCPwnojc4GVhx+RnjmFReS5vfGQBYkzQxHoI8y3gHFW9TVX/J7AE+LZ3ZR3vilnj2LyvxSZaNiZgYg2QiKrW9fu9YRjbnrbPzi0B4MX1+xK1S2NMDGINgddE5HURuV1EbgdeAV71rqzjTchLZ8mkPJ5fV23zpBoTIEPdCzNNRJaq6teBnwDz3Me7wCMJqK/PHy8sZcfBdpsjxJgAGaoH8n2gBUBVn1fVr6nq14AX3PcS5uq5JYxJivDM+3sTuVtjzCCGCpBxqrpx4Ivua5MG21BEUkVktYhsEJHNIvKd06iT7LRkrp0/nufX1dDc2X06H2WMGSFDBUjOIO+lDbHtYeAyVZ0PnA1cKSLnDae4gW67YBKd3T08u8Z6IcYEwVABskZE/mLgiyLy58CgkyWr49i06snu47RGQOeUZrN4Yi6Praqip9cGU43x21AB8lXgDhF5S0QecB//AdyFs2DUoNxV6dbjTED0hqq+d7oF33nhZKoaOnhlY+3pfpQx5jQNNaHQAVW9APgOsNt9fEdVz3cvbx+Uqvao6tk4984sEZE5A9ucbFmHwVw5u5gZ4zL54fJK64UY47NY74VZoaoPuY8/DHcnqtoErACuPMl7JyzrMGjBEeGvL5/O9ro264UY4zMvF5YqFJEc93ka8Glg60h89tVzSphelMn33/yY7p5hLXpnjBlBXl6OXgKsEJEPgfdxxkBeHokPjkSEv7mygp0H23liVdVIfKQxJg6xTig0bKr6IbDAq8+/4qwiLpxWwINvVnLd2aXkZqR4tStjzCkk7Ia4kSYifPuaWbR2dfP/f7/N73KMOSOFNkAAZhZnccfSyTzx3h6beNkYH4Q6QAD+97KZTMxP529/8yGdR2wdXWMSKfQBkpYS5bufm0dVQwf//NqInOQxxsQo9AECzvoxdyydxC//azevbx7y+jZjzAgZFQEC8I2rKphbms3Xn93A3kab+tCYRBg1ATImKcrDf7oQVbjnyXV0ddt4iDFeGzUBAlCen84DN85nY00z9z67gV67V8YYT42qAAFYNruYv72yglc+rOXBNz/2uxxjRjXPrkT10xcunsLOg2089IftjM9J45Yl5X6XZMyoNCoDRET4P9fP5UDLYb75wkbSkqNcv6DU77KMGXVG3SHMMSlJEX5y6yLOm5zPvc9u4LVNduu/MSNt1AYIQGpylJ/etpj5Zdn81ZMf8NIGW5jKmJE0qgMEIGNMEr+8cwkLJ+bylac/4HG7/d+YETPqAwRgbGoyj965hMtmFvF3v93EQ8srbYU7Y0aAlzOSTRCRFSLykbsuzJCTMHspNTnKv926iM8tKOWBNz7m3l9vsIvNjDlNXp6FOQrcq6rrRCQLWCsib6jqRx7uc1DJ0QgP3DifSQUZfO+Nj9nV0M5Pbl1EUVaqXyUZE2q
e9UBUtVZV17nPW4EtgO/nUkWcSZl//GcL2VrbyrUPrWTN7ka/yzImlBIyBiIik3CmNzxhXZjhLuswUq6aW8JzXzqfMckRbnpkFQ+v2G6XvhszTJ4HiIhkAr8BvqqqLQPfH+6yDiNp9vhsXv7yhVw1p5j7X9/Gbb9YTV1rV0JrMCbMPA0QEUnGCY8nVPV5L/cVr6zUZB66ZQH/94/nsnpXI8sefJsX19fYWRpjYuDlWRgBfgZsUdXvebWfkSAi/Om55bzy1xcxuSCDrzy9ni88ttZ6I8YMwcseyFLgVuAyEVnvPq72cH+nbVpRJs998QK+eXUFb318kGUPvs1Tq/fY2IgxpyBB6qovXrxY16xZ43cZAGyva+Obz29k9e5G5k/I4Z+um828shy/yzLGEyKyVlUXD3e7M+JK1HhMK8rkmS+cx/dvOpt9TZ1c9/BK7nt+I/Vth/0uzZjAsAAZhIhw/YJS/nDvJdy5dDK/XrOXS/5lBT94s5L2w0f9Ls8Y31mAxCArNZlvXzOL3/+vi7l4RiEPvvkxl9y/gsfe3W2Le5szmo2BxGHdnkN899WtrN7dyIS8NL50yTRuWFRGSpLlsQmneMdALEDipKqs2FbHD96sZEN1MyXZqXzxkqncdM4EUpOjfpdnzLBYgPhEVXm7sp6HlleypuoQhVljuOvCydxyTjnZ6cl+l2dMTCxAfKaqrNrZyI9WVLJyewNpyVFuWFTGHUsnMaUw0+/yjBlUvAEyKidV9oOIcP7UfM6fms9H+1r4+cpdPPP+Xh5/r4rLZhZxx9LJXDA1n0hE/C7VmBFjPRAP1bV28cSqPTy+qoqG9iNMyk/n5iXl3LCojILMMX6XZ0wfO4QJsK7uHl7btJ8nV+9h9a5GkqPCslnF3LKk3HolJhAsQEJie10bT6/ew3Prqmnq6KYsN43rzy7l+gWlTCuysRLjDwuQkOnq7uH1zfv5zboa/rPyIL0K88qyuf7sUv5o/ngKs+wQxySOBUiI1bV28dL6ffx2fQ2balqIRoQLpxXw2XklLJs1jpz0FL9LNKOcBcgoUXmgld+ur+HF9fuoPtRJUsQ5u3PlnGI+M7vYBl+NJwIXICLyc+AaoE5V58SyjQXIJ1SVTTUt/G5TLb/btJ9d9e1EBM6ZlMfVc0u4YtY4SnPS/C7TjBJBDJCLgTbgUQuQ06OqbDvQyqsb9/Paplo+PtAGQEVxFpdWFHF5RRELynOJ2tkcE6fABQj0zcb+sgXIyNpxsI0VW+tYvqWO93c3crRXyUlP5lMzCrm0oohLZhTauIkZFrsS9QwytTCTqYWZ/PlFU2jp6uadj+tZvvUAb207yG/X7yMiMK8sh4umF3DhtAIWlOfancLGE773QETkbuBugPLy8kVVVbb4dbx6epUN1U28tbWOd7bXs2FvE70K6SlRzp2cx9JpBVw0vZAZ4zJx5rw2xmGHMOYEzZ3drNrZwMrt9fxnZT0769sBKMwawwVT81kyOY9zJ+cxtdAC5UxnhzDmBNlpyXxmtnP6F6CmqZOVlfW8s72e/9rRwIvr9wGQn5HCOZPyOHdKHksm51FRPNYGZE1MvDwL8xTwKaAAOAD8var+bLBtrAeSOKpKVUMHq3c18t6uRt7b1UD1oU4AslKTOGdSHudMymNheQ7zynJIS7FJkkazQB7CDJcFiL9qmjp53w2U1bsa2HHQOeSJRoSzSrJYMCGXhRNzWDAhl4n56XbYM4pYgJgR19h+hA/2HOKDPU2s23OIDXubaD/SA0BeRgoLJuSwoDyHBeW5zCnNJjvNZmALKxsDMSMuLyOFy88ax+VnjQOcszyVda2sq2pygmVvE8u31vW1n5SfzpzSbOa6j9kWKqOe9UDMaWnu6GZ9dRObaprZWN3Mxppmapo6+96f6IbKPAuVQLMeiPFFdnoyl8wo5JIZhX2vNbYfYWNNc1+orN/TxCsf1va9PyEvjbOKx1JRMpZZJVlUFI+lPC/dJlYKIQsQM+LyMlJ
OGiqbapweyke1LWytbeHNLQc4tm55ekqUmcVOmMwqyaKiZCwzi7MYm2q9lSCzQxjjm84jPVTWtbKltoUtta1s3e/8bO7s7mtTlptGRfFYZozLZPq4TKYXZTG1MNNOK48wO4QxoZOWEmVemXOdyTGqyv6WLrbWtjo9lf2tbK1t4a1tdRx1uysiTrDMKMpimhsq04symVaUScYY+5NOJPu2TaCICCXZaZRkp3FpRVHf6909vVQ1tFN5oI3KOvdxoJV3Kus50m994tKcNKYVZfYFyuSCDKYUZlKQmWLXrXjAAsSEQnI0wrSiLKYVZXFVv9eP9vSyp7GDyro2truhUlnXxqqdDRw++kmwZI1JYnJhhhMoBZlMLsxgSkEGkwoyyLReS9zsmzOhlhSNMKUwkymFmXxm9iev9/Qq+5o62Vnfzq6Dbeyqb2dnfTtrdh/ipQ376D/0V5Q1himFGUwuyGRKgRMykwszKMtNY0ySjbUMxgLEjErRiDAhL50JeenHnQ0CZ0b8qoYOdtW3sbO+nZ0H29lV387rm/fT2H6kr50IlIxNpTw/nYl5GZTnp1Oel85E96dN2mQBYs5AqcnOKeOZxVknvNfUcYRd9U6gVDV0sLexg6rGDpZvraO+7fBxbcemJjExP4PyvHQ3ZNL7QqYkO+2MuKPZAsSYfnLSU1hQnsKC8twT3us4cpQ9jR1UNXSwp6HDed7YweZ9zby+eX/fWSKA5KhQmpNGaW4aZTnpzs/cNEpz0ijLS2dc1hiSouGfJc4CxJgYpackUVE8lorisSe8d7Snl9rmLvY0dvSFTPWhDqoPdZ609xKNCCXZqU6g5KY74eKGTFlOOiU5qSSHIGA8DRARuRL4ARAFfqqq3/Vyf8b4JSka6RtzWXqS97u6e6hp6qTmUCfVhzqpaXLCpeZQJyu313Ogteu4gd2IwLixqZTlpjE+xzmtPT4nleKxqYzPSaM4O5X8DP9PTXsWICISBR4GPg1UA++LyEuq+pFX+zQmqFKTo32TYZ/MkaO91DZ/EjDVTZ19PZi1VYc40FJLd8/xV42nJEUoyT4+VMZnp1KS7T7PSSM3PdnTkPGyB7IE2K6qOwFE5GngOsACxJgBUpIiTMzPYGJ+xknf7+1V6tsPU9vURW1zF7XNne7PLmqbOlm9q5EDLV3HjcMAjHFDxrk4L5WSnFSKs9MoHusEz7jsMRRkxL/aoZcBUgrs7fd7NXCuh/szZtSKRISirFSKslKZP+HkbXp6lYa2w+xzQ2Vg0Kza2cCB1sP0DAiZ5Gj8PRTfB1H7L+sAHBaRTX7WM0wFQL3fRQxT2GoOW70QzppnxrORlwFSA/TPyjL3teOo6iPAIwAisiaeOwL9ErZ6IXw1h61eCG/N8Wzn5Xmi94HpIjJZRFKAm4GXPNyfMSbBPOuBqOpREfkr4HWc07g/V9XNXu3PGJN4no6BqOqrwKvD2OQRr2rxSNjqhfDVHLZ64QyqOVAzkhljwiX418oaYwIr4QEiIleKyDYR2S4i3zjJ+2NE5Bn3/ffcBbp9FUPNXxORj0TkQxFZLiIT/ahzQE2D1tyv3Z+IiIqIr2cNYqlXRG50v+fNIvJkoms8ST1D/V2Ui8gKEfnA/du42o86+9XzcxGpO9WlEuL4ofvv+VBEFg75oaqasAfOYOoOYAqQAmwAZg1o85fAv7nPbwaeSWSNcdZ8KZDuPv9SGGp222UBbwOrgMVBrheYDnwA5Lq/FwX9O8YZV/iS+3wWsNvnmi8GFgKbTvH+1cDvAAHOA94b6jMT3QPpu7xdVY8Axy5v7+864Ffu8+eAy8XfO4aGrFlVV6hqh/vrKpxrXvwUy/cM8E/APwNdiSzuJGKp9y+Ah1X1EICq1uGvWGpW4Nitu9nAvgTWdwJVfRtoHKTJdcCj6lgF5IhIyWCfmegAOdnl7aWnaqOqR4FmID8h1Z1cLDX3dxdOivtpyJrd7ukEVX0lkYWdQizf8QxghoisFJFV7p3efoql5n8
APi8i1ThnI7+cmNLiNty/df8vZR9NROTzwGLgEr9rGYyIRIDvAbf7XMpwJOEcxnwKp4f3tojMVdUmX6sa3C3AL1X1ARE5H3hMROaoau9QG4ZFonsgsVze3tdGRJJwun4NCanu5GK6JF9ErgC+BVyrqocHvp9gQ9WcBcwB3hKR3TjHuy/5OJAay3dcDbykqt2qugv4GCdQ/BJLzXcBvwZQ1XeBVJz7ZIIqpr/14yR4ECcJ2AlM5pOBp9kD2tzD8YOov/Z54CmWmhfgDKhN97PW4dQ8oP1b+DuIGst3fCXwK/d5AU5XOz/gNf8OuN19fhbOGIj4/LcxiVMPon6W4wdRVw/5eT78A67G+d9jB/At97V/xPmfG5yUfhbYDqwGpvj5hcdY85vAAWC9+3gp6DUPaOtrgMT4HQvOYddHwEbg5qB/xzhnXla64bIeWOZzvU8BtUA3To/uLuCLwBf7fccPu/+ejbH8TdiVqMaYuNmVqMaYuFmAGGPiZgFijImbBYgxJm4WIMaYuFmAnCFEpEdE1ovIJhF5VkTST+OzPiUiL7vPrx3ibt8cEfnLfr+PF5Hn4t23CRYLkDNHp6qerapzgCM45//7uLdyD/vvQVVf0sFXHMzBucP6WPt9qnrDcPdjgskC5Mz0DjBNRCa581k8CmwCJojIMhF5V0TWuT2VTOib+2KriKwDPnfsg0TkdhH5kft8nIi8ICIb3McFwHeBqW7v5353n5vc9qki8gsR2ejOmXFpv898XkReE5FKEfmXxH49JlYWIGcY9/6iq3CuNATnfpJ/VdXZQDvwd8AVqroQWAN8TURSgX8H/ghYBBSf4uN/CPyHqs7HmXdiM/ANYIfb+/n6gPb3AKqqc3FuPPuVuy+As4GbgLnATSJyiuWUjJ8sQM4caSKyHicU9gA/c1+vUmfuB3Duf5gFrHTb3gZMBCqAXapaqc6ly4+fYh+XAT8GUNUeVW0eoqYLj32Wqm4FqnBu2wdYrqrNqtqFc/m677O8mRPZ7fxnjk5VPbv/C+48Te39XwLeUNVbBrQ7brsE6X9Hcw/2txpI1gMx/a0ClorINAARyRCRGcBWYJKITHXb3XKK7ZfjTOmIiERFJBtoxZk+4GTeAf7MbT8DKAe2jcQ/xCSGBYjpo6oHcSYZekpEPgTeBSrcw4i7gVfcQdRTTSf4FeBSEdkIrMWZI7QB55Bok4jcP6D9vwIRt/0zOLe++z2XihkGuxvXGBM364EYY+JmAWKMiZsFiDEmbhYgxpi4WYAYY+JmAWKMiZsFiDEmbhYgxpi4/TdmSF1cK8E9EAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot sample cost vs y_hat (prediction), for y (truth) = 1\n", "y_hat = np.linspace(0, 1, 1000)\n", "cost = -np.log(y_hat)\n", "plt.figure(figsize=(4,3))\n", "plt.plot(y_hat, cost)\n", "plt.xlabel('Prediction'); plt.ylabel('Cost')\n", "plt.xlim(0, 1); plt.ylim(0, 7); plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/markbaum/Python/python/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: divide by zero encountered in log\n", " This is separate from the ipykernel package so we can avoid doing imports until\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARAAAADUCAYAAABUBHwDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAGZJJREFUeJzt3XlwXeWZ5/Hv40XWLlmW5UW2vBsb7OANMEtIWJstkIZMgAATGDquTjNdYaA6FZqe6plJVTodpjMhBc0MWxIITSAE0hQhJGCgCQTb2MZgYxvvK7Z2ydqXq2f+uMdGCNm6urrnLtbvU3XLdznnvI9ubn685z3nvMfcHRGReIxIdQEikrkUICISNwWIiMRNASIicVOAiEjcFCAiErfQAsTMTjGzDb0eR8zszrDaE5Hks2ScB2JmI4GDwFnuvjf0BkUkKZK1C3MRsFPhIXJySVaA3AA8naS2RCRJQt+FMbMs4BPgNHev7OfzFcAKgLy8vKXz5s0LtR4R+bx169bVuPv4wa43Koxi+rgcWN9feAC4+8PAwwDLli3ztWvXJqEkEQH4+HATK55cC+sujGt4IRm7MDei3ReRtNTWFWFvbWvc64caIGaWB1wCPB9mOyISn0jP0IYwQt2FcfcWYFyYbYhI/IYaIDoTVWQYU4CISNwUICISt8gQT+NQgIgMY5GeniGtrwARGcYiQ8sPBYjIcKYeiIjErVuDqCISLx2FEZG4KUBEJG4KEBGJmwJEROKmE8lEJG7qgYhI3LoiaRwgZlZsZs+Z2VYz22JmZ4fZnogMTmf30E4kC3tKw/uBV9z9a8HcqLkhtycig5C2AWJmRcD5wK0A7t4JdIbVnogMXkd3hFEjLO71w9yFmQFUAz8zs/fN7NFgikMRSROd3T1kjYo/BsIMkFHAEuAhd18MtADf67uQma0ws7Vmtra6ujrEckSkr85I+gbIAeCAu68OXj9HNFA+w90fdvdl7r5s/PhB35ZCRIags7uHrJFpGCDufhjYb2anBG9dBGwOqz0RGbyh7sKEfRTmb4GngiMwu4DbQm5PRAahY4i7MGHf1mEDsCzMNkQkfmm7CyMi6a+zu4cxaTqIKiJpLp0P44pImuuM9DBm1Mi411eAiAxj6oGISNw0iCoicUvnM1FFJM21d0V0FEZE4tPS0U3emPhPB1OAiAxT7k5rZ4S8MToKIyKD1BnpobvHyc1SD0REBqm1IwJAXpZ6ICIySC2d3QDk
agxERAartfNoD0QBIiKD1NJxtAeiXRgRGaRE9EBCnQ/EzPYATUAE6HZ3zQ0ikiaaj/ZAhjCIGvaMZAAXuHtNEtoRkUE40tYFQFHO6Li3oV0YkWGqMQiQwjQOEAf+aGbrzGxFfwvotg4iqdHY1oUZFKTxYdzz3H0JcDlwh5md33cB3dZBJDUa27oozB7NiDS9Mx3ufjD4twp4ATgzzPZEJHaNbV0U58a/+wIhBoiZ5ZlZwdHnwKXAprDaE5HBaWjtGtIAKoR7FGYC8IKZHW3n39z9lRDbE5FBaGhL4wBx913A6WFtX0SGpqapg1mlQ7vfvQ7jigxD7k5NcwelBWOGtB0FiMgw1NzRTUd3D6X5WUPajgJEZBiqbuoAoDRfPRARGaSa5k5AASIicahpVg9EROJ0LEAKNAYiIoN0sKGNrJEjKM1TD0REBulgfRuTi7OHdB0MKEBEhqUD9W2Uj80Z8nYUICLD0IH6NqYU5w55OwoQkWGmvStCTXMHU9QDEZHBOtjQBqBdGBEZvAP10QCZMla7MCIySLurmwGYNi4DAsTMRprZ+2b2UthticjAdlQ3U5A9irIhXokLyemBfAfYkoR2RCQG2yubmV2WTzDZ15CEGiBmNgW4Eng0zHZEJHY7q5uZU5afkG2F3QP5CfBdoCfkdkQkBvUtndQ0dzKnrCAh2wtzUuWrgCp3XzfAcrovjEiSbKtsAmB2BvRAzgWuDu6P+yvgQjP7Zd+FdF8YkeTZeLARgNPKCxOyvdACxN3vcfcp7j4duAF43d1vDqs9ERnYpoONTCzMpqwgOyHbiylAzOzJWN4TkfT24cFGFk4pStj2Yu2BnNb7hZmNBJbG2oi7v+nuVw2mMBFJrKb2LnbXtLCwPEkBYmb3mFkT8AUzOxI8moAq4N8TVoWIhG79vgbcYXFFccK2ecIAcfd/cvcC4D53LwweBe4+zt3vSVgVIhK6NbtrGTXCWDptbMK2GesuzEvB/W0xs5vN7MdmNi1hVYhI6FbvqmNBeRG5WYm7IWWsAfIQ0GpmpwN3AzuBJxJWhYiEqr0rwgcHGjhrRklCtxtrgHS7uwPXAA+4+4NAYk5lE5HQrdtbT1fEOWtmYgMk1r5Mk5ndA9wCfNHMRgBDu623iCTN61uryBo1guUzxyV0u7H2QK4HOoD/4u6HgSnAfQmtRERC8/rWKs6ZNS6h4x8QY4AEofEUUBRc49Lu7hoDEckAu6qb2V3TwkXzyhK+7VjPRP06sAb4T8DXgdVm9rWEVyMiCff61ioALgghQGLtz9wLnOHuVQBmNh54DXgu4RWJSEK99OEh5k8qTMgcqH3FOgYy4mh4BGoHsa6IpMiemhY27G/gq4smh7L9WHsgr5jZH4Cng9fXAy+HUpGIJMyLH3yCGXzl9BQEiJnNBia4+9+Z2bXAecFH7xIdVBWRNOXu/Pb9g5w5vYTJxUO/B0x/BtoN+QlwJCjmeXe/y93vAl4IPhORNPXuzlp21bTwtaVTQmtjoACZ4O4b+74ZvDf9RCuaWbaZrTGzD8zsIzP7n0OoU0QG6clVeynOHR3a7gsMHCAnuu53oD5RB3Chu58OLAIuM7PlgylOROJzqLGNP26u5PplU8kePTK0dgYKkLVm9q2+b5rZXwEnnCzZo5qDl6ODh8dVpYgMyi9X7aXHnZvOCvei+YGOwtwJvGBmN/FpYCwDsoC/HGjjwcxl64DZwIPuvnoItYpIDBrbunjiz3u5YsEkKhJw+8oTOWGAuHslcI6ZXQAsCN7+nbu/HsvG3T0CLDKzYqJBtMDdN/VexsxWACsAKioqBlu/iPTxxJ/30NTRzR0XzA69rZjOA3H3N4A34m3E3RvM7A3gMmBTn88eBh4GWLZsmXZxRIbgSHsXj72zm4vnl3Hq5MTcuuFEwryx1Pig54GZ5QCXAFvDak9E4KE3d9LQ2sWdF89NSnuJvbb3syYBvwjGQUYAz7r7SyG2JzKsHWxo47G3d3Pt4nIWJHDm9RMJ
LUDc/UNgcVjbF5HPuu+VrRhw91+ckrQ2dUGcyEng7e01/HbDJ6w4fyblIZ223h8FiEiGa+uM8PcvbGRGaV5Sjrz0FuYYiIgkwU9WbmNfXStPf2t5qGed9kc9EJEM9u7OWh5+axc3nDGVs2cldsLkWChARDJUfUsn/+2ZDcwYl8d/v+rUlNSgABHJQO7Od3/zIXUtnfz0xsXkjUnNaIQCRCQDPfD6Dl7dXMn3Lp+XtHM++qMAEckwr2w6zL+8uo1rF5dz27nTU1qLAkQkg2w62Mhdz25g0dRifnDtQswspfUoQEQyxK7qZr75+BrG5mbx8C1Lk37Itj8KEJEMcKixjVseWwPAk7efSVlhdooritKJZCJp7lBjGzc9sprGti5+tWI5M8fnp7qkYxQgImlsf10r33h0FfUtXfzstjNSesSlPwoQkTS1q7qZmx5dTWtnhKf+6ixOn3qiOc5TQwEikobe21PHiifWYmY8/a3lSZldLB5hzkg21czeMLPNwX1hvhNWWyInk3/fcJCbHllNcW4Wz3/7nLQNDwi3B9IN3O3u682sAFhnZq+6++YQ2xTJWD09zk9WbuenK7dz5owS/t/NSxmbl5Xqsk4ozBnJDgGHgudNZrYFKAcUICJ91LV0cuczG3hrWzXXLZnCD65dwJhRqT/PYyBJGQMxs+lEpzf83H1hdFsHGe7W76vnjqfWU9vcyQ/+ciE3njk15WeYxir0ADGzfOA3wJ3ufqTv57qtgwxX3ZEeHnpzJ/ev3M6k4mx+8+1zWDglvQ7TDiTUADGz0UTD4yl3fz7MtkQyyc7qZu569gM+2N/A1adP5vvXLKAod3Sqyxq00ALEon2wx4At7v7jsNoRySSRHufnf97Dj17ZSk7WSB74xmKu+sLkVJcVtzB7IOcCtwAbzWxD8N7fu/vLIbYpkrY+2N/Avb/dyKaDR7hwXhk/vHZh2lzTEq8wj8K8DWTGSJBIiBrbuvjff/iYX67ey/j8MTzwjcVcuXBSxgyUnojORBUJSVekh1+9t5/7X9tGXUsn3zx7OndfOpeC7Mwb6zgeBYhIgrk7r22p4p9+v4Vd1S2cOaOEn992atpdCJcIChCRBFq3t44fvfIxq3fXMXN8Ho/852VcPL/spNhd6Y8CRCQB3ttTx/2vbeftHTWU5mfx/a8u4IYzpjJ65Mk9Z5cCRGQI1uyu4/6V23hnRy2l+Vnce8V8blpeQW7W8Pi/1vD4K0USKNLjvLr5MI/8aTfr9tZTmj+Gf7hyPjedNY2crPS/fiWRFCAiMWrp6ObXa/fz+Dt72FfXytSSHP7xK6dywxkVwy44jlKAiAxgT00LT6/Zx9Nr9nGkvZslFcXcc/k8Lj1tIiNHnJyDo7FSgIj0oyvSw2ubK3lq9T7e3lHDyBHGX5w2gdvPm8nSaWNTXV7aUICI9LKvtpVn1+7nmbX7qW7qYHJRNndfMpevnzGVCRl+2nkYFCAy7DW2dvG7jYd4fv0B1u6tZ4TBBaeUcdPyCr40t2zY76aciAJEhqWuSA//8XE1z79/gNe2VNHZ3cPssny+e9kpfHVROZOLc1JdYkZQgMiw0RXp4c87a/n9xkP84aPD1Ld2UZKXxTfOrOC6JVNYUF540p4xGhYFiJzUOrt7eGdnDS9/eIg/bq6ksa2LvKyRXDR/Atcsmsz5c8ef9GeLhinMCYUeB64Cqtx9QVjtiPTV3NHNn7ZV8+qWSl7bXMmR9m4Kxozi4lMncPmCiZw/d3xa3Jj6ZBBmD+TnwAPAEyG2IQJEbwG5ckslK7dWsWpXLV0RpzA7GhpXLpzEeXNKM2KW80wT5oRCbwWzsYskXFekhw37G1i5pYrXt1ayrbIZgFnj87jt3BlcOK+MZdPGMkq7J6HSGIhkBHdnZ3ULb2+v5u0dNazaVUdzRzejRhhnzijh+jMquGheGdNL81Jd6rCS8gDRfWHkeGqbO3hnZ200NLbX8EljOwAVJblcvWgyX5xdyrlzSik8iWb4
yjQpDxDdF0aOqmnu4L3ddawOHlsORW8jVJg9inNnl3LHhaV8cfZ4KsblprhSOSrlASLD1+HGdlbvrmX17jrW7K5jR1V0HCNn9EiWThvL3ZfM5Ytzx7OwvEhng6apMA/jPg18GSg1swPAP7r7Y2G1J+mtp8fZXtXM+/vqWbe3njV76thb2wpAwZhRLJs+luuWTOGsmSUsmFxE1igNfmaCMI/C3BjWtiX9NbR28v6+Bt7fV8/6fQ18sL+Bpo5uAMbmjuaM6SXcsnway2eOY/6kQvUwMpR2YWTIuiI9bKtsYsP+BtbvjYbGrpoWAEYYzJtYyNWLJrOkYiyLK4qZUZqnU8ZPEgoQGZSjYbHxQCMbDzay6WAjWw430dndA8C4vCwWVxRz3dIpLKkYyxemFJE3Rj+zk5X+l5XjGigsCsaMYkF5EbeeM50F5UWcPqWIipJc9S6GEQWIAFDX0smWQ0fYcugIWw83sfXwEbZVNn8aFtmjWDD507BYWF7EtJJcRmjsYlhTgAwznd097KppZuuhJrYcPsKWQ01sPXSEqqaOY8uMLxjDvIkFCgsZkALkJBXpcQ7Ut7K9spntVc1sr2xiy+EmdlQ10RWJnq+XNXIEs8vyOW9OKadOKmTexELmTSqgNH9MiquXTKEAyXBdkR721rayo6qJ7ZXN7KhuZntlMzurm+kIdj8AJhSOYd7EQr40dzzzJxUwb2IhM8fnaS4MGRIFSIZo64ywp7aFnUFA7KhqZntVE7trWo71KADKi3OYMyGfc2ePY05ZAbMn5DNrfD5FObpeRBJPAZJGuiI9HKhvY3dNM7uqW9hT28LumhZ2V7ccu5AMwAymleQyu6yAi+ZPYE5ZPrPLokGhQ6aSTPq1JZm7c/hIO7urW9hV08KemiAkalrYV9dKd8+nvYminNHMKM1j+cxxzCjNY8b4PGaU5jFrfL5m1JK0oAAJQXtXhAP1beyva2Vf8Nhb23rsdVtX5Niy2aNHMH1cHvMmFXD5wonMKM1nRmkeM0vzGJuXlcK/QmRgCpA4uDu1LZ3sqwtCobaVvUE47K9r5fCRdrzXxAQ5o0dSUZLL1JJczptTeiwgppfmMbEwW4dHJWMpQPrh7jS2dXGgvo2DDW0cDP7d3yskWjojn1lnQuEYKkpyOXvWOKaV5FExLoeKklwqSvIozc/S2ZlyUhqWAdLT41Q3d/QJiNZjQXGwvu1zAZE9egRTxuYyrSSX5TPHMW1cbhAQ0Z6FxiRkOAo1QMzsMuB+YCTwqLv/MMz2jmrvinC4sZ1PGtv4pKH904BoaONAfRuHGtrpjPR8Zp2inNGUF+cwbVwe58wqZcrYHMqLcygP/i3JUy9CpK8wJxQaCTwIXAIcAN4zsxfdffNQttvS0c2hxnYON7ZzqLEt+u+Ro6/bOdzYRn1r1+fWG18whvLiHBaUF3HZaROjATE2h/LiXCYXZ1OgeTVFBi3MHsiZwA533wVgZr8CrgGOGyARd7ZVNh0Lgk+D4tPAONLe/bn1SvKymFiYzeSibJZOK2ZSUQ4TC7OZWJTN5OIcJhVlaxdDJARhBkg5sL/X6wPAWSdaYfMnR7j0/7x17LUZlOaPYVJRNtPG5bJ8ZgkTi6KBMLEom0lF2UwoVDiIpErKB1F739YB6Nj7z1dt6v35nqRXNCilQE2qixikTKs50+qFzKz5lHhWCjNADgJTe72eErz3Gb1v62Bma919WYg1JVSm1QuZV3Om1QuZW3M864V5KeZ7wBwzm2FmWcANwIshticiSRbmrOzdZvZfgT8QPYz7uLt/FFZ7IpJ8oY6BuPvLwMuDWOXhsGoJSabVC5lXc6bVC8OoZnPX3SRFJD6ajkpE4pb0ADGzy8zsYzPbYWbf6+fzMWb2TPD5ajObnuwa+6lpoJrvMrPNZvahma00s2mpqLNPTSesuddy15mZm1lKjxrEUq+ZfT34nj8ys39Ldo391DPQ76LCzN4ws/eD38YVqaizVz2Pm1mVmW06zudm
Zj8N/p4PzWzJgBt196Q9iA6m7gRmAlnAB8CpfZb5G+D/Bs9vAJ5JZo1x1nwBkBs8/3Ym1BwsVwC8BawClqVzvcAc4H1gbPC6LN2/Y6LjCt8Onp8K7ElxzecDS4BNx/n8CuD3gAHLgdUDbTPZPZBjp7e7eydw9PT23q4BfhE8fw64yFJ7FduANbv7G+7eGrxcRfScl1SK5XsG+D7wz0B7P58lUyz1fgt40N3rAdy9Ksk19hVLzQ4UBs+LgE+SWN/nuPtbQN0JFrkGeMKjVgHFZjbpRNtMdoD0d3p7+fGWcfduoBEYl5Tq+hdLzb3dTjTFU2nAmoPu6VR3/10yCzuOWL7jucBcM3vHzFYFV3qnUiw1/w/gZjM7QPRo5N8mp7S4Dfa3nvpT2U8mZnYzsAz4UqprOREzGwH8GLg1xaUMxiiiuzFfJtrDe8vMFrp7Q0qrOrEbgZ+7+7+Y2dnAk2a2wN17BloxUyS7BxLL6e3HljGzUUS7frVJqa5/MZ2Sb2YXA/cCV7t7R9/Pk2ygmguABcCbZraH6P7uiykcSI3lOz4AvOjuXe6+G9hGNFBSJZaabweeBXD3d4FsotfJpKuYfuufkeRBnFHALmAGnw48ndZnmTv47CDqsykeeIql5sVEB9TmpLLWwdTcZ/k3Se0gaizf8WXAL4LnpUS72uPSvObfA7cGz+cTHQOxFP82pnP8QdQr+ewg6poBt5eCP+AKov/12AncG7z3v4j+lxuiKf1rYAewBpiZyi88xppfAyqBDcHjxXSvuc+yKQ2QGL9jI7rbtRnYCNyQ7t8x0SMv7wThsgG4NMX1Pg0cArqI9uhuB/4a+Ote3/GDwd+zMZbfhM5EFZG46UxUEYmbAkRE4qYAEZG4KUBEJG4KEBGJmwJkmDCziJltMLNNZvZrM8sdwra+bGYvBc+vHuBq32Iz+5teryeb2XPxti3pRQEyfLS5+yJ3XwB0Ej3+f0xwKfegfw/u/qKf+I6DxUSvsD66/Cfu/rXBtiPpSQEyPP0JmG1m04P5LJ4ANgFTzexSM3vXzNYHPZV8ODb3xVYzWw9ce3RDZnarmT0QPJ9gZi+Y2QfB4xzgh8CsoPdzX9DmpmD5bDP7mZltDObMuKDXNp83s1fMbLuZ/Si5X4/ESgEyzATXF11O9ExDiF5P8q/ufhrQAvwDcLG7LwHWAneZWTbwCPAVYCkw8Tib/ynwH+5+OtF5Jz4CvgfsDHo/f9dn+TsAd/eFRC88+0XQFsAi4HpgIXC9mU1F0o4CZPjIMbMNRENhH/BY8P5ej879ANHrH04F3gmW/SYwDZgH7Hb37R49dfmXx2njQuAhAHePuHvjADWdd3Rb7r4V2Ev0sn2Ale7e6O7tRE9fT/ksb/J5upx/+Ghz90W93wjmaWrp/Rbwqrvf2Ge5z6yXJL2vaI6g32paUg9EelsFnGtmswHMLM/M5gJbgelmNitY7sbjrL+S6JSOmNlIMysCmohOH9CfPwE3BcvPBSqAjxPxh0hyKEDkGHevJjrJ0NNm9iHwLjAv2I1YAfwuGEQ93nSC3wEuMLONwDqic4TWEt0l2mRm9/VZ/l+BEcHyzxC99D3Vc6nIIOhqXBGJm3ogIhI3BYiIxE0BIiJxU4CISNwUICISNwWIiMRNASIicVOAiEjc/j+PcH4xxz8S/gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot sample cost vs y_hat (prediction), for y (truth) = 0\n", "y_hat = np.linspace(0, 1, 1000)\n", "cost = -np.log(1 - y_hat)\n", "plt.figure(figsize=(4,3))\n", "plt.plot(y_hat, cost)\n", "plt.xlabel('Prediction'); plt.ylabel('Cost')\n", "plt.xlim(0, 1); plt.ylim(0, 7); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **3 Logistic Regression with Gradient Descent**\n", "1. Gradient descent searches for the weights that minimize the logistic regression cost; the log-loss cost used here is convex, so the search is not trapped in the local minima of a non-convex cost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **01 Defining the Logistic Functions for Gradient Descent**\n", "Like the naive Bayes classifier, logistic regression is a **probability-based classifier**" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Compute predictions from the current weights\n", "def compute_prediction(X, weights):\n", "    z = np.dot(X, weights)\n", "    predictions = sigmoid(z)\n", "    return predictions\n", "\n", "# One gradient descent step: update the weights using all training samples\n", "def update_weights_gd(X_train, y_train, weights, learning_rate):\n", "    predictions = compute_prediction(X_train, weights)\n", "    weights_delta = np.dot(X_train.T, y_train - predictions)\n", "    m = y_train.shape[0]\n", "    weights += learning_rate / float(m) * weights_delta\n", "    return weights  # updated weights (numpy.ndarray)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Compute the mean log-loss cost\n", "def compute_cost(X, y, weights):\n", "    predictions = compute_prediction(X, weights)\n", "    cost = np.mean(-y * np.log(predictions) - (1-y) * np.log(1-predictions))\n", "    return cost  # float\n", "\n", "# Train the logistic regression model\n", "def train_logistic_regression(X_train, y_train, max_iter, learning_rate, fit_intercept=False):\n", "    X_train_np = X_train  # used as-is when no intercept column is added\n", "    if fit_intercept:\n", "        intercept = np.ones((X_train.shape[0], 1))\n", "        # np.hstack() joins arrays with the same number of rows side by side\n", "        X_train_np = np.hstack((intercept, X_train))\n", "    weights = np.zeros(X_train_np.shape[1])\n", "    for iteration in range(max_iter):\n", "        weights = update_weights_gd(X_train_np, y_train, weights, learning_rate)\n", "        if iteration % 1000 == 0:  # print the cost every 1,000 iterations\n", "            print(\"{:,}th Logistic Cost : {:.5f}\".format(\n", "                iteration, compute_cost(X_train_np, y_train, weights)))\n", "    return weights" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Predict results for new data with the trained weights\n", "def predict(X, weights):\n", "    if 
X.shape[1] == weights.shape[0] - 1:\n", "        # the weights include an intercept, so prepend a column of ones\n", "        intercept = np.ones((X.shape[0], 1))\n", "        X = np.hstack((intercept, X))\n", "    return compute_prediction(X, weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **02 Training the Model on Example Data**\n", "1. Train with weights that **include an intercept** term\n", "1. Use a **learning rate of 0.1** and train the logistic regression model for **10,000 iterations**, printing the cost every 1,000 iterations" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0th Logistic Cost : 0.57440\n", "1,000th Logistic Cost : 0.00395\n", "2,000th Logistic Cost : 0.00202\n", "3,000th Logistic Cost : 0.00136\n", "4,000th Logistic Cost : 0.00103\n", "5,000th Logistic Cost : 0.00082\n", "6,000th Logistic Cost : 0.00069\n", "7,000th Logistic Cost : 0.00059\n", "8,000th Logistic Cost : 0.00052\n", "9,000th Logistic Cost : 0.00046\n" ] } ], "source": [ "# The cost keeps decreasing as the number of iterations grows\n", "X_train = np.array([[6, 7],[2, 4],[3, 6],[4, 7],[1, 6],\n", "                    [5, 2],[2, 0],[6, 3],[4, 1],[7, 2]])\n", "y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n", "weights = train_logistic_regression(X_train, y_train,\n", "                                    max_iter = 10000,\n", "                                    learning_rate = 0.1,\n", "                                    fit_intercept = True)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([9.99999394e-01, 8.71880199e-04, 9.96881227e-01, 3.66361408e-03])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test = np.array([[6, 1],[1, 3],[3, 1],[4, 5]])\n", "predictions = predict(X_test, weights)\n", "predictions" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXwAAAEKCAYAAAARnO4WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAF5JJREFUeJzt3X2QlNWd9vHvj2FmmkFkBAaMaMAkJJSxWF46GiOJRERRLNi4ZpVaN6sxzsbEfcYyu66SfdZKKlZlYyrRsoIVfCNZFQthdS0irO6uuBoN2IBRZEyyAibwEBlFBIZ5Y/g9f3RreBlgZpjTZ7rP9anqmpnTN32urqGuvuf03fdt7o6IiJS/AbEDiIhIcajwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRAyMHeBAI0aM8LFjx8aOISJSMtasWfOOu9d1Z9t+Vfhjx44ll8vFjiEiUjLM7K3ubqslHRGRRKjwRUQSocIXEUmECl9EJBEqfBGRRAQrfDP7lJm9csBtl5ndGGo+kRA6O+EnP4Hx42H0aPj61+GPf4ydKm0vvQQzZsDJJ8N558HKlbETlQ4rxhWvzKwC2Aqc7e5HPIQom826DsuU/uSaa2DxYti7N/9zZSWMGAEbNkBtbdxsKVq5Ei65BFpa/jRWU5P/Hc2aFS1WVGa2xt2z3dm2WEs604E3j1b2Iv3N5s3w6KN/KnuAjg7YuRPuuy9arKTddNPBZQ/538+NWjvolmIV/pXAoiLNJdIn1q6FqqrDx1ta4Lnnip9HYP36rsfffDO//CZHF7zwzawKmA08doT7680sZ2a5pqam0HFEuu2jH+26RCor4ROfKH4egZEjux6vrYWKiuJmKUXF2MO/GFjr7m93dae7L3D3rLtn6+q6dToIkaKYMgXGjYOBh5yApLISbrghTqbUzZuXX7M/UE0N/P3fx8lTaopR+HPRco6UIDN4+mmYPj2/tJPJwNix8ItfwMc/Hjtdmq6/Pl/6J5yQL/qamvz6/S23xE5WGoIepWNmg4HfAx9z9/ePtb2O0pH+6r338m8OnnJK/oVA4mprg7ffzi/xZDKx08TVk6N0gp4t092bgeEh5xAphpNOyt+kf6iuzr/HIj2jT9qKiCRChS8ikggVvohIIlT4IiKJUOGLiCRChS8ikggVvohIIlT4IiKJUOGLiCRChS8ikggVvohIIlT4IiKJUOGLiCRChS8ikggVvohIIlT4IiKJUOGLiCRChS8ikggVvohIIoIWvpnVmtkSM3vDzBrN7JxQc23bBv/wD3D22XDVVbBuXaiZpDv27YOFC2HaNPjiF+Ghh6CzM3YqkbQFvYg5cBewwt0vN7MqoCbEJJs3w5QpsGcPtLdDLgePPw6LF8OsWSFmlKNxhy99CZ59Fpqb82Mvvwz//u/534lZ3HwiqQq2h29mQ4EvAPcDuHu7u+8MMdc//RO8/36+7AH274e9e6G+Pl8+UlzPP39w2UP+++XL88UvInGEXNI5HWgCHjSzdWZ2n5kNDjHRf/5n18sFO3bkl3qkuJ59Nv+Ce6i2tvx9IhJHyMIfCEwG7nH3SUAzcMuhG5lZvZnlzCzX1NTUq4mGDet6fP9+GDKkVw8px2HECMhkDh+vrs7fJyJxhCz8LcAWd19V+HkJ+ReAg7j7AnfPunu2rq6uVxN961sw+JC/HaqrYfZsFX4MV14JA7r4nzVgAFx+efHziEhesMJ39z8CfzCzTxWGpgMbQsz11a/C17+eL/mhQ/N7l5//PNx/f4jZ5FiGD4df/CK/Nz9kSP42ciSsWJH//YhIHOYB39U0s4nAfUAVsBG4xt3fO9L22WzWc7lcr+d7911Yvx4++lE4/fReP4z0kX37YM2a/FE5U6ZARUXsRCLlx8zWuHu2O9sGPSzT3V8BuhWkLwwfDuedV6zZ5FgGDsx/LkJE+gd90lZEJBEqfBGRRKjwRUQSocI
XEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSEbTwzWyzmb1mZq+YWe+vTi4S2caN8PjjsVOIHJ9i7OF/0d0ndveq6iL90R13wHXXwf79sZOI9N7A2AFE+qu9e+H22/NfH3kEWlvzpT90KFx9NUyYEDuhSM+ELnwHnjYzB37q7gsCzyfSZ6qq4Pnn87cPPPggDBsGX/1qvFwivRV6SWequ08GLga+aWZfOHQDM6s3s5yZ5ZqamgLHEem+gQNh5Uq4+OI/jVVWwm9+A2eeGS2WSK8FLXx331r4uh14HDiri20WuHvW3bN1dXUh44j0mBmsXZsv/5Ejob0dfv/72KlEeidY4ZvZYDMb8sH3wIXA+lDziYSwbx+ccw6sWwdvvgn19bBnT+xUIr0Tcg1/FPC4mX0wzyPuviLgfCJ9rrLy4MMxf/rTeFlEjlewwnf3jcCfhXp8ERHpGX3SVkQkESp8EZFEqPBFRBKhwhcRSYQKX0QkESp8EZFEqPBFRBKhwhcRSYQKX0QkESp8EZFEqPBFRBKhwhcRSYQKX0QkESp8EZFEqPBFRBKhwhcRSYQKX0QkESp8EZFEqPBFRBIR8iLmAJhZBZADtrr7paHnKwdr18Izz8BJJ8Hll8OwYbETiUg5CF74QAPQCJxYhLlKmjtccw089hi0t0NVFdx0Ezz5JJx/fux0IlLqgi7pmNmpwCzgvpDzlIsnnoAlS2DvXti3L/+1uRn+4i/yLwAiIscj9Br+ncDNwP7A85SFBx/MF/yh9u+HF14ofh4RKS/BCt/MLgW2u/uaY2xXb2Y5M8s1NTWFilMS3Ht3n4hId4Tcwz8XmG1mm4FHgfPN7KFDN3L3Be6edfdsXV1dwDj939/8DQwe3PV9U6cWN4uIlJ9ghe/ut7r7qe4+FrgS+G93vyrUfOXgssvg0kvzpW8GmQzU1MDixVBdHTudiJS6YhylI900YAAsWgS/+lX+sMzaWrjiChg1KnYyESkHRSl8d18JrCzGXKXODM45J38TEelL+qStiEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIok4auGb2Ylm9vEuxieEiyQiIiEcsfDN7C+BN4ClZva6mX3mgLsXhg4mIiJ962h7+POAKe4+EbgG+Fcz+1LhPgueTERE+tTRPmlb4e7bANx9tZl9EVhmZqcBOnejiEiJOdoe/u4D1+8L5T8NmAN8OnAuERHpY0cr/OuBAWZ2xgcD7r4bmAl8LXQwERHpW0csfHf/tbv/DlhsZv9oeYOAHwHfKFpCERHpE905Dv9s4DTgReBl4P+Rv7iJiIiUkO4UfgfQAgwCMsAmd9c1akVESkx3Cv9l8oX/GeDzwFwzeyxoKhER6XPduQDKte6eK3y/DZhjZn8dMJOIiARwzD38A8r+wLF/DRNHRERC0cnTREQSEazwzSxjZqvN7NeFc/F8J9RcB1q/HvbrLWWRsvbaa6/x6KOPsm7duthRSkrIi5i3Aee7+x4zqwReMLPl7v6rUBPu2gVTpsBTT8H06aFmEZFYWlpamD17Ni+++CIVFRV0dnYyadIkli9fzpAhQ2LH6/eCFb67O7Cn8GNl4RbkHDxvvQWNjbBqFbS3w113QUcHjBoFkyaFmFFEYpg3bx4vvPACra2tH47lcjkaGhp44IEHIiYrDZbv5UAPblYBrAE+AfzE3f/xaNtns1nP5Q57j/iYvvtduO02qK6GtjYYNAhaWuC882Dlyl5FF5F+qLa2lvfff/+w8erqalpaWjBL70S+ZrbG3bPd2Tbom7bu3lk4vfKpwFlmduah25hZvZnlzCzX1NTUq3n++Z/hnnvgg9cuM/irv4Knnz6O8CLS7xy4Z3+gjo4O9uvNu2MqylE67r4TeJb8idcOvW+Bu2fdPVtXV9frOUaOzC/jDBiQ37s/4QSoqjqO0CL
S71xwwQUMGHBwbZkZ5557LhUVFZFSlY6QR+nUmVlt4ftBwAzyV9AKYtUqmDoVNm+GG26AtWtDzSQisdx1113U1taSyWQAyGQynHjiidxzzz2Rk5WGYGv4heve/gyoIP/Cstjdv3u0f9PbNXwRScc777zDvffeSy6XY+LEidTX1zNq1KjYsaLpyRp+0Ddte0qFLyLSM/3mTVsREek/VPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFL5KQLVu2sHr1anbv3h07igDNzc2sXr2at956qyjzBSt8MzvNzJ41sw1m9rqZNYSaS0SObteuXcycOZNx48YxY8YMRo0axfe+973YsZJ25513MnLkSGbMmMH48eOZNm0aO3bsCDpnyD38fcC33P0M4LPAN83sjIDzicgRfOUrX2HlypW0traya9cuWlpa+P73v8/ixYtjR0vSihUr+Pa3v83evXvZtWsXra2tvPTSS3z5y18OOm+wwnf3be6+tvD9bqARGB1qPhHp2o4dO1ixYgVtbW0HjTc3N/ODH/wgUqq03XHHHezdu/egsfb2dl588UW2bNkSbN6irOGb2VhgErCqi/vqzSxnZrmmpqZixBFJynvvvcfAgQO7vG/79u1FTiMA27Zt63K8qqqKkD0YvPDN7ARgKXCju+869H53X+DuWXfP1tXVhY4jkpwxY8aQyWQOG6+oqOCCCy6IkEhmzpxJZWXlYePuzhlnhFv5Dlr4ZlZJvuwfdvd/CzmXiHRt4MCB3H333dTU1Hw4VllZydChQ7ntttsiJkvXzTffzEknnURVVdWHYzU1Nfzwhz+kuro62Lwhj9Ix4H6g0d1/FGoeETm2uXPn8swzzzBnzhwmTJjAN77xDV599VXGjBkTO1qSTj75ZF599VUaGhqYMGECl1xyCcuWLaO+vj7ovObuYR7YbCrwPPAasL8wPM/dnzrSv8lms57L5YLkEREpR2a2xt2z3dm263dy+oC7vwBYqMcXEZGe0SdtRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl+C6uzsZP/+/cfesJ9rbm7m3XffJdQ1oKXn9u3bFztCyQlW+Gb2gJltN7P1oeaQ/q+hoYFbb701doxe27FjB3PmzGHYsGGccsopjB8/nl/+8pexYyVv586dnHLKKWzbti12lJIScg9/ITAz4ONLP9bW1kZLSwsPP/wwCxcupLW1lba2ttixesTdueiii1i+fDnt7e20t7fz29/+losuuohNmzbFjpekzs5O2traeOKJJ2hqamLp0qW0tbXR2dkZO1pJCFb47v4/wI5Qjy/919q1a8lkMtTU1NDR0UFzczODBg0ik8nw+uuvx47XbevWraOxsZGOjo6Dxtvb25k/f36kVGmbO3cumUyGr33tawDceOONZDIZZs2aFTlZaYi+hm9m9WaWM7NcU1NT7DjSByZPnszPf/5zampqaG5uprm5mcGDB7No0SI+/elPx47XbZs3b6aiouKw8Y6ODt54440IiWT+/PlMnz6dTCYDQCaT4XOf+xwLFy6MG6xERC98d1/g7ll3z9bV1cWOI33kqquuoqamhqqqKiorK6mtreWKK66IHatHJk6cSHt7+2HjgwYNYurUqRESyYgRI6ivr6e1tZXBgwfT2trKtddey8knnxw7WkmIXvhSnhobG3n33XdpaGjg+uuvZ+vWrSW37v2xj32Myy67jJqamg/HKioqGDJkCNddd13EZGl74oknGDlyJPfeey+jR49m6dKlsSOVDAt5mJmZjQWWufuZ3dk+m816LpcLlkeKp62tjU2bNjF+/HgANmz
YwLhx46isrIycrGf27dvHj3/8Y+bPn8+ePXuYNWsWt99+O6NHj44dLVkbN26krq6OIUOGsHfvXrZu3cq4ceNix4rGzNa4e7Zb24YqfDNbBEwDRgBvA7e5+/1H+zcqfBGRnulJ4Q8MFcLd54Z6bBER6Tmt4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgiVPgiIolQ4YuIJEKFLyKSCBW+iEgigha+mc00s9+Y2f+a2S0h5yoXjY2NXHjhhWQyGYYPH868efNob2+PHUtEykCwi5ibWQXwE2AGsAV42cyedPcNoeYsdVu3buWzn/0su3fvxt1pa2vjzjvv5He/+x2PPfZY7HgiUuJC7uGfBfyvu29093bgUWBOwPlK3t13301rayvu/uFYS0sLy5YtY9OmTRGTiUg5CFn4o4E/HPDzlsKYHMHq1au7XL6prq5mwwb9YSQixyf6m7ZmVm9mOTPLNTU1xY4T1cSJE6msrDxsvL29nU9+8pMREolIOQlZ+FuB0w74+dTC2EHcfYG7Z909W1dXFzBO/9fQ0EB1dfVBY5lMhmnTpjFu3LhIqUSkXIQs/JeBcWZ2uplVAVcCTwacr+SNGTOG5557jrPOOgszY9CgQVx99dUsWbIkdjQRKQPBjtJx931mdgPwH0AF8IC7vx5qvnIxefJkVq1aRWdnJwMGDMDMYkcSkTIRrPAB3P0p4KmQc5SrioqK2BFEpMxEf9NWRESKQ4UvIpIIFb6ISCJU+CIiiVDhi4gkwg48b0tsZtYEvNUHDzUCeKcPHic2PY/+Rc+j/ymX53I8z2OMu3frU6v9qvD7ipnl3D0bO8fx0vPoX/Q8+p9yeS7Feh5a0hERSYQKX0QkEeVa+AtiB+gjeh79i55H/1Muz6Uoz6Ms1/BFRORw5bqHLyIihyirwjezB8xsu5mtj52lt8zsNDN71sw2mNnrZtYQO1NvmVnGzFab2a8Lz+U7sTP1lplVmNk6M1sWO8vxMLPNZvaamb1iZrnYeXrLzGrNbImZvWFmjWZ2TuxMPWVmnyr8Hj647TKzG4POWU5LOmb2BWAP8HN3PzN2nt4ws48AH3H3tWY2BFgD/HkpXvzd8ud2Huzue8ysEngBaHD3X0WO1mNmdhOQBU5090tj5+ktM9sMZN29pI9dN7OfAc+7+32F623UuPvO2Ll6y8wqyF8g6mx374vPInWprPbw3f1/gB2xcxwPd9/m7msL3+8GGinRawF73p7Cj5WFW8ntYZjZqcAs4L7YWQTMbCjwBeB+AHdvL+WyL5gOvBmy7KHMCr/cmNlYYBKwKm6S3isshbwCbAeecfdSfC53AjcD+2MH6QMOPG1ma8ysPnaYXjodaAIeLCyz3Wdmg2OHOk5XAotCT6LC76fM7ARgKXCju++Knae33L3T3SeSv6bxWWZWUkttZnYpsN3d18TO0kemuvtk4GLgm4Vl0FIzEJgM3OPuk4Bm4Ja4kXqvsCQ1G3gs9Fwq/H6osN69FHjY3f8tdp6+UPiT+1lgZuwsPXQuMLuw9v0ocL6ZPRQ3Uu+5+9bC1+3A48BZcRP1yhZgywF/LS4h/wJQqi4G1rr726EnUuH3M4U3Ou8HGt39R7HzHA8zqzOz2sL3g4AZwBtxU/WMu9/q7qe6+1jyf3b/t7tfFTlWr5jZ4MKBABSWQC4ESu6INnf/I/AHM/tUYWg6UHIHNRxgLkVYzoHA17QtNjNbBEwDRpjZFuA2d78/bqoeOxf4a+C1wto3wLzC9YFLzUeAnxWOQBgALHb3kj6sscSNAh7P71MwEHjE3VfEjdRrfwc8XFgO2QhcEzlPrxReeGcAf1uU+crpsEwRETkyLemIiCRChS8ikggVvohIIlT4IiKJUOGLiCRChS/SDWa2wsx2lvrZMiVtKnyR7rmD/OcjREqWCl/kAGb2GTN7tXAu/8G
F8/if6e7/BeyOnU/keJTVJ21Fjpe7v2xmTwLfAwYBD7l7yZ1+QKQrKnyRw30XeBloBf5P5CwifUZLOiKHGw6cAAwBMpGziPQZFb7I4X4K/F/gYeBfImcR6TNa0hE5gJl9Behw90cKZ/l80czOB74DjAdOKJyJ9Vp3/4+YWUV6SmfLFBFJhJZ0REQSocIXEUmECl9EJBEqfBGRRKjwRUQSocIXEUmECl9EJBEqfBGRRPx/8M2pXUIpanQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 분류 판단을 위한 임계치로 0.5를 설정하여 결과를 출력한다\n", "# Train 데이터로 학습한 모델이, 새로운 데이터에 대해서도 잘 적용됨을 볼 수 있다\n", "plt.scatter(X_train[:,0], X_train[:,1], \n", " marker = 'o',\n", " c = ['b'] * 5 + ['k'] * 5)\n", "\n", "colours = ['k' if prediction >= 0.5 else 'b' for prediction in predictions]\n", "plt.scatter(X_test[:,0], X_test[:,1], \n", " marker = '*',\n", " c = colours)\n", "plt.xlabel('x1'); plt.ylabel('x2'); plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **4 CTR Prediction with Gradient Descent and Logistic Regression**\n", "Click Through Rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **01 Training the Algorithm on 1K Samples**\n", "Train on the **first 1,000 samples** of the dataset and test on the **following 1,000**" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'C1': '1005', 'C14': '15706', 'C15': '320', 'C16': '50', 'C17': '1722', 'C18': '0', 'C19': '35', 'C20': '-1', 'C21': '79', 'app_category': '07d7df22', 'app_domain': '7801e8d9', 'app_id': 'ecad2386', 'banner_pos': '0', 'device_conn_type': '2', 'device_model': '44956a24', 'device_type': '1', 'site_category': '28905ebd', 'site_domain': 'f3845767', 'site_id': '1fbe01fe'}\n", "{'C1': '1005', 'C14': '15704', 'C15': '320', 'C16': '50', 'C17': '1722', 'C18': '0', 'C19': '35', 'C20': '100084', 'C21': '79', 'app_category': '07d7df22', 'app_domain': '7801e8d9', 'app_id': 'ecad2386', 'banner_pos': '0', 'device_conn_type': '0', 'device_model': '711ee120', 'device_type': '1', 'site_category': '28905ebd', 'site_domain': 'f3845767', 'site_id': '1fbe01fe'}\n" ] } ], "source": [ "import csv\n", "\n", "# Read n rows of the ad click data, skipping the first `offset` rows\n", "def read_ad_click_data(n, offset=0):\n", "    X_dict, y = [], []\n", "    with open('./data/train.csv', 'r') as csvfile:\n", "        reader = csv.DictReader(csvfile)\n", "        for i in range(offset):\n", "            next(reader)\n", "        i = 0\n", "        for row in reader:\n", "            i += 1\n", "            y.append(int(row['click']))\n", "            # drop the label and the identifier-like columns\n", "            del row['click'], row['id'], row['hour'], row['device_id'], row['device_ip']\n", "            X_dict.append(dict(row))\n", "            if i >= n: break\n", "    return X_dict, y\n", "\n", "n = 1000\n", "X_dict_train, y_train = read_ad_click_data(n)\n", "print(X_dict_train[0])\n", "print(X_dict_train[1])" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0th Logistic Cost : 0.68107\n", "1,000th Logistic Cost : 0.41219\n", "2,000th Logistic Cost : 0.40069\n", "3,000th Logistic Cost : 0.39300\n", 
"4,000th Logistic Cost : 0.38696\n", "5,000th Logistic Cost : 0.38186\n", "6,000th Logistic Cost : 0.37740\n", "7,000th Logistic Cost : 0.37341\n", "8,000th Logistic Cost : 0.36979\n", "9,000th Logistic Cost : 0.36646\n", "--- 7.747s seconds ---\n" ] } ], "source": [ "# One-hot encode the feature dicts so they can be used for training\n", "from sklearn.feature_extraction import DictVectorizer\n", "dict_one_hot_encoder = DictVectorizer(sparse=False)\n", "X_train = dict_one_hot_encoder.fit_transform(X_dict_train)\n", "X_dict_test, y_test_1k = read_ad_click_data(n, n)\n", "X_test = dict_one_hot_encoder.transform(X_dict_test)\n", "\n", "X_train_1k = X_train\n", "y_train_1k = np.array(y_train)\n", "\n", "import timeit\n", "start_time = timeit.default_timer()\n", "weights = train_logistic_regression(X_train_1k, y_train_1k, max_iter=10000, learning_rate=0.01, fit_intercept=True)\n", "print(\"--- %0.3fs seconds ---\" % (timeit.default_timer() - start_time))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The ROC AUC on testing set is: 0.663\n" ] } ], "source": [ "# Evaluate the model trained above with ROC AUC on the test set\n", "X_test_1k = X_test\n", "predictions = predict(X_test_1k, weights)\n", "from sklearn.metrics import roc_auc_score\n", "print('The ROC AUC on testing set is: {0:.3f}'.format(roc_auc_score(y_test_1k, predictions)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **02 Using Stochastic Gradient Descent (SGD)**\n", "1. Use the **update_weights_sgd()** function\n", "1. 
**With SGD**, train on the **first 1,000 samples** of the dataset and test on the **following 1,000**" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# SGD variant: update the weights once per training sample\n", "def update_weights_sgd(X_train, y_train, weights, learning_rate):\n", "    for X_each, y_each in zip(X_train, y_train):\n", "        prediction = compute_prediction(X_each, weights)\n", "        weights_delta = X_each.T * (y_each - prediction)\n", "        weights += learning_rate * weights_delta\n", "    return weights" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Logistic regression training that uses update_weights_sgd()\n", "def train_logistic_regression(X_train, y_train, max_iter, learning_rate, fit_intercept=False):\n", "    if fit_intercept:\n", "        intercept = np.ones((X_train.shape[0], 1))\n", "        X_train = np.hstack((intercept, X_train))\n", "    weights = np.zeros(X_train.shape[1])\n", "    for iteration in range(max_iter):\n", "        weights = update_weights_sgd(X_train, y_train, weights, learning_rate)\n", "        # Check the cost for every 2 (for example) iterations\n", "        if iteration % 2 == 0:\n", "            print(\"{:,}th SGD Logistic : {:.5f}\".format(\n", "                iteration, compute_cost(X_train, y_train, weights)))\n", "    return weights" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0th SGD Logistic : 0.41983\n", "2th SGD Logistic : 0.40212\n", "4th SGD Logistic : 0.39185\n", "--- 0.155s seconds ---\n" ] } ], "source": [ "# Train the SGD model on the 1K samples\n", "start_time = timeit.default_timer()\n", "weights = train_logistic_regression(X_train_1k, y_train_1k, max_iter=5, learning_rate=0.01, fit_intercept=True)\n", "print(\"--- %0.3fs seconds ---\" % (timeit.default_timer() - start_time))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The ROC AUC on testing set is: 0.672\n" ] } ], "source": [ "predictions = predict(X_test_1k, 
weights)\n", "print('The ROC AUC on testing set is: {0:.3f}'.format(roc_auc_score(y_test_1k, predictions)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **03 Training the SGD Algorithm on 10K Samples**\n", "1. Train on the **first 10,000 samples** of the dataset and test on the **following 10,000**\n", "1. Five SGD passes over 10,000 samples finish in under a second (0.947 s, versus 7.747 s for batch gradient descent on 1,000 samples), and the test ROC AUC rises to 0.720" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0th SGD Logistic : 0.41497\n", "2th SGD Logistic : 0.40601\n", "4th SGD Logistic : 0.40105\n", "--- 0.947s seconds ---\n" ] } ], "source": [ "n = 10000\n", "X_dict_train, y_train = read_ad_click_data(n)\n", "dict_one_hot_encoder = DictVectorizer(sparse=False)\n", "X_train = dict_one_hot_encoder.fit_transform(X_dict_train)\n", "\n", "X_train_10k = X_train\n", "y_train_10k = np.array(y_train)\n", "\n", "# Train the SGD model based on 10,000 samples\n", "start_time = timeit.default_timer()\n", "weights = train_logistic_regression(X_train_10k, y_train_10k, max_iter=5, learning_rate=0.01, fit_intercept=True)\n", "print(\"--- %0.3fs seconds ---\" % (timeit.default_timer() - start_time))" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The ROC AUC on testing set is: 0.720\n" ] } ], "source": [ "# Test on the next 10,000 samples, encoded with the encoder fitted on the training set\n", "X_dict_test, y_test_10k = read_ad_click_data(10000, 10000)\n", "X_test_10k = dict_one_hot_encoder.transform(X_dict_test)\n", "\n", "predictions = predict(X_test_10k, weights)\n", "print('The ROC AUC on testing set is: {0:.3f}'.format(roc_auc_score(y_test_10k, predictions)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **5 CTR Prediction with the scikit-learn SGD Algorithm**\n", "Using the scikit-learn module" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The ROC AUC on testing set is: 0.721\n" ] } ], "source": [ "# Use scikit-learn package\n", "# Note: loss='log' matches the scikit-learn version used here; newer releases name it 'log_loss'\n", "from sklearn.linear_model import SGDClassifier\n", "sgd_lr = SGDClassifier(loss='log', penalty=None,\n", "                       fit_intercept=True, max_iter=5,\n", "                       learning_rate='constant', eta0=0.01)\n", "sgd_lr.fit(X_train_10k, y_train_10k)\n", "\n", "# Probability of the positive class (click) for each test sample\n", "predictions = sgd_lr.predict_proba(X_test_10k)[:, 1]\n", "print('The ROC AUC on testing set is: {0:.3f}'.format(roc_auc_score(y_test_10k, predictions)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **6 SGD with Regularization**\n", "**Feature selection** with a logistic regression model that uses **L1 regularization**" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n", " eta0=0.01, fit_intercept=True, l1_ratio=0.15,\n", " learning_rate='constant', loss='log', max_iter=5, n_iter=None,\n", " n_jobs=1, penalty='l1', power_t=0.5, random_state=None,\n", " shuffle=True, tol=None, verbose=0, warm_start=False)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The L1 penalty pushes the weights of unhelpful features toward zero\n", "l1_feature_selector = SGDClassifier(loss = 'log', penalty = 'l1',\n", "                                    alpha = 0.0001, fit_intercept = True,\n", "                                    max_iter = 5, learning_rate = 'constant',\n", "                                    eta0 = 0.01)\n", "l1_feature_selector.fit(X_train_10k, y_train_10k)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/markbaum/Python/python/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. 
It will be removed in a future NumPy release.\n", " from numpy.core.umath_tests import inner1d\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Original dataset : (10000, 2820)\n", "Features selected by Random Forest : (10000, 500)\n" ] } ], "source": [ "# Selecting the important features: transform() does not work on this estimator (coding error), so the call stays commented out\n", "# X_train_10k_selected = l1_feature_selector.transform(X_train_10k)\n", "# Note: X_train_10k_selected is created by the Random Forest cell in section 9, which must be run before this cell\n", "print(\"Original dataset :\", X_train_10k.shape)\n", "print(\"Features selected by Random Forest :\", X_train_10k_selected.shape)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.5962561 -0.44022485 -0.42428472 -0.42428472 -0.41595815 -0.41548047\n", " -0.31676318 -0.30903059 -0.30744771 -0.28089655]\n", "[ 559 2172 2566 2370 1540 34 579 2116 278 577]\n" ] } ], "source": [ "# The 10 smallest (most negative) weights and the indices of the corresponding features\n", "print(np.sort(l1_feature_selector.coef_)[0][:10])\n", "print(np.argsort(l1_feature_selector.coef_)[0][:10])" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.28423705 0.2842371 0.29318359 0.29969314 0.31062841 0.34092667\n", " 0.34649048 0.34906087 0.36057499 0.40919723]\n", "[2769 363 546 2275 547 2149 1503 2580 1519 2761]\n" ] } ], "source": [ "# The 10 largest weights and the indices of the corresponding features (largest last)\n", "print(np.sort(l1_feature_selector.coef_)[0][-10:])\n", "print(np.argsort(l1_feature_selector.coef_)[0][-10:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **7 Online Learning on Large Datasets**\n", "Real-time data arrives in **chunks**, so preprocessing and training are done **one small batch at a time**" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- 2.423s seconds ---\n" ] } ], "source": [ "# The number of iterations is set to 1 if using partial_fit.\n", "sgd_lr = SGDClassifier(loss='log', penalty=None, fit_intercept=True, max_iter=1, learning_rate='constant', eta0=0.01)\n", "\n", "import timeit\n", "start_time = timeit.default_timer()\n", "\n", "# There are 40428968 labelled samples; here the first 20 chunks of 1,000 samples are used for training\n", "# and the next 1,000 for testing (the *_100k and *_next10k variable names are kept from the original code)\n", "for i in range(20):\n", "    X_dict_train, y_train_every_100k = read_ad_click_data(1000, i * 1000)\n", "    X_train_every_100k = dict_one_hot_encoder.transform(X_dict_train)\n", "    sgd_lr.partial_fit(X_train_every_100k, y_train_every_100k, classes=[0, 1])\n", "\n", "print(\"--- %0.3fs seconds ---\" % (timeit.default_timer() - start_time))" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The ROC AUC on testing set is: 0.694\n" ] } ], "source": [ "# Test on the chunk of 1,000 samples that follows the last training chunk\n", "X_dict_test, y_test_next10k = read_ad_click_data(1000, (i + 1) * 1000)\n", "X_test_next10k = dict_one_hot_encoder.transform(X_dict_test)\n", "predictions = sgd_lr.predict_proba(X_test_next10k)[:, 1]\n", "print('The ROC AUC on testing set is: {0:.3f}'.format(roc_auc_score(y_test_next10k, predictions)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **8 Multiclass Classification**\n", "Model text labelled with **20 categories** in total, **using SGD**" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Set up the text preprocessing helpers for the newsgroups data\n", "from nltk.corpus import names\n", "from nltk.stem import WordNetLemmatizer\n", "all_names = set(names.words())\n", "lemmatizer = WordNetLemmatizer()\n", "\n", "# True only if every character in the string is alphabetic\n", "def letters_only(astr):\n", "    for c in astr:\n", "        if not c.isalpha():\n", "            return False\n", "    return True\n", "\n", "# Keep alphabetic, non-name words, lowercased and lemmatized\n", "def clean_text(docs):\n", "    cleaned_docs = []\n", "    for doc in docs:\n", "        cleaned_docs.append(' '.join([lemmatizer.lemmatize(word.lower())\n", "                                      for word in doc.split()\n", "                                      if letters_only(word)\n", "                                      and word not in all_names]))\n", "    return cleaned_docs" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# Load the newsgroups data\n", "from sklearn.datasets import fetch_20newsgroups\n", "data_train = fetch_20newsgroups(subset='train', categories=None, random_state=42)\n", "data_test = fetch_20newsgroups(subset='test', categories=None, random_state=42)\n", "\n", "# Preprocess the text\n", "cleaned_train = clean_text(data_train.data)\n", "cleaned_test = clean_text(data_test.data)\n", "# Collect the labels\n", "label_train = data_train.target\n", "label_test = data_test.target" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Convert the preprocessed text to TF-IDF features\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "tfidf_vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5, stop_words='english', max_features=40000)\n", "term_docs_train = tfidf_vectorizer.fit_transform(cleaned_train)\n", "term_docs_test = tfidf_vectorizer.transform(cleaned_test)\n", "\n", "# Parameter grid for the grid search\n", "from sklearn.model_selection import GridSearchCV\n", "parameters = {'penalty': ['l2', None],\n", "              'alpha' : [1e-07, 1e-06, 1e-05, 1e-04],\n", "              'eta0' : [0.01, 0.1, 1, 10]}" ] }, { "cell_type": "code", "execution_count": 38, 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'alpha': 1e-07, 'eta0': 10, 'penalty': 'l2'}\n" ] } ], "source": [ "# SGD 분류기를 활용하여 예측모델을 생성한다\n", "sgd_lr = SGDClassifier(loss='log', learning_rate='constant', eta0=0.01, fit_intercept=True, max_iter=10)\n", "grid_search = GridSearchCV(sgd_lr, parameters, n_jobs=-1, cv=3)\n", "\n", "grid_search.fit(term_docs_train, label_train)\n", "print(grid_search.best_params_)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The accuracy on testing set is: 79.6%\n" ] }, { "data": { "text/plain": [ "SGDClassifier(alpha=1e-07, average=False, class_weight=None, epsilon=0.1,\n", " eta0=10, fit_intercept=True, l1_ratio=0.15,\n", " learning_rate='constant', loss='log', max_iter=10, n_iter=None,\n", " n_jobs=1, penalty='l2', power_t=0.5, random_state=None,\n", " shuffle=True, tol=None, verbose=0, warm_start=False)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy = sgd_lr_best.score(term_docs_test, label_test)\n", "print('The accuracy on testing set is: {0:.1f}%'.format(accuracy*100))\n", "grid_search.best_estimator_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "# **9 Feature Selection with Random Forest**\n", "1. **feature_importances_** : reports the importance of each feature" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "# Use a random forest to extract the 500 most important features\n", "from sklearn.ensemble import RandomForestClassifier\n", "random_forest = RandomForestClassifier(n_estimators=100,\n", "                                       criterion='gini',\n", "                                       min_samples_split=30,\n", "                                       n_jobs=-1)\n", "random_forest.fit(X_train_10k, y_train_10k)\n", "\n", "top500_feature = np.argsort(random_forest.feature_importances_)[-500:]\n", "X_train_10k_selected = X_train_10k[:, top500_feature]" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", "[2040 2764 2280 1896 1001 1454 756 135 2676 764]\n" ] } ], "source": [ "# Print the 10 lowest feature importances and their feature indices\n", "print(np.sort(random_forest.feature_importances_)[:10])\n", "print(np.argsort(random_forest.feature_importances_)[:10])" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.00755481 0.00772242 0.00798538 0.00818412 0.00886733 0.00905481\n", " 0.00942318 0.00986043 0.01424382 0.01465488]\n", "[2307 549 1284 1503 1540 1923 1085 314 554 393]\n" ] } ], "source": [ "# Print the 10 highest feature importances and their feature indices (ascending, so the most important comes last)\n", "print(np.sort(random_forest.feature_importances_)[-10:])\n", "print(np.argsort(random_forest.feature_importances_)[-10:])" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "C18=2\n" ] } ], "source": [ "# The feature at index 393 has the highest importance\n", "print(dict_one_hot_encoder.feature_names_[393])" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10000, 500)\n" ] } ], "source": [ "# Select the top 500 features and print the resulting shape\n", "top500_feature = 
np.argsort(random_forest.feature_importances_)[-500:]\n", "X_train_10k_selected = X_train_10k[:, top500_feature]\n", "print(X_train_10k_selected.shape)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }