{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Building an Attrition Classifier\n", "\n", "- Uses the HR dataset published by IBM\n", "- Builds a decision-tree classifier that predicts whether an employee will leave the company" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "pd.options.display.max_columns=None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Loading the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "IBM Kaggle dataset: https://www.kaggle.com/pavansubhasht/ibm-hr-analytics-attrition-dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Age Attrition BusinessTravel DailyRate Department \\\n", "0 41 Yes Travel_Rarely 1102 Sales \n", "1 49 No Travel_Frequently 279 Research & Development \n", "2 37 Yes Travel_Rarely 1373 Research & Development \n", "3 33 No Travel_Frequently 1392 Research & Development \n", "4 27 No Travel_Rarely 591 Research & Development \n", "\n", " DistanceFromHome Education EducationField EmployeeCount EmployeeNumber \\\n", "0 1 2 Life Sciences 1 1 \n", "1 8 1 Life Sciences 1 2 \n", "2 2 2 Other 1 4 \n", "3 3 4 Life Sciences 1 5 \n", "4 2 1 Medical 1 7 \n", "\n", " EnvironmentSatisfaction Gender HourlyRate JobInvolvement JobLevel \\\n", "0 2 Female 94 3 2 \n", "1 3 Male 61 2 2 \n", "2 4 Male 92 2 1 \n", "3 4 Female 56 3 1 \n", "4 1 Male 40 3 1 \n", "\n", " JobRole JobSatisfaction MaritalStatus MonthlyIncome \\\n", "0 Sales Executive 4 Single 5993 \n", "1 Research Scientist 2 Married 5130 \n", "2 Laboratory Technician 3 Single 2090 \n", "3 Research Scientist 3 Married 2909 \n", "4 Laboratory Technician 2 Married 3468 \n", "\n", " MonthlyRate NumCompaniesWorked Over18 OverTime PercentSalaryHike \\\n", "0 19479 8 Y Yes 11 \n", "1 24907 1 Y No 23 \n", "2 2396 6 Y Yes 15 \n", "3 23159 1 Y Yes 11 \n", "4 16632 9 Y No 12 \n", "\n", " PerformanceRating RelationshipSatisfaction StandardHours \\\n", "0 3 1 80 \n", "1 4 4 80 \n", "2 3 2 80 \n", "3 3 3 80 \n", "4 3 4 80 \n", "\n", " StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n", "0 0 8 0 \n", "1 1 10 3 \n", "2 0 7 3 \n", "3 0 8 3 \n", "4 1 6 3 \n", "\n", " WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "0 1 6 4 \n", "1 3 10 7 \n", "2 3 0 0 \n", "3 3 8 7 \n", "4 3 2 2 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager \n", "0 0 5 \n", "1 1 7 \n", "2 0 0 \n", "3 3 0 \n", "4 2 2 " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets = pd.read_csv('./inputs/HR-Employee-Attrition.csv')\n", "datasets.head()" ] }, { "cell_type": "code", 
"execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1470, 35)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset contains both categorical and numerical features." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Age int64\n", "Attrition object\n", "BusinessTravel object\n", "DailyRate int64\n", "Department object\n", "DistanceFromHome int64\n", "Education int64\n", "EducationField object\n", "EmployeeCount int64\n", "EmployeeNumber int64\n", "EnvironmentSatisfaction int64\n", "Gender object\n", "HourlyRate int64\n", "JobInvolvement int64\n", "JobLevel int64\n", "JobRole object\n", "JobSatisfaction int64\n", "MaritalStatus object\n", "MonthlyIncome int64\n", "MonthlyRate int64\n", "NumCompaniesWorked int64\n", "Over18 object\n", "OverTime object\n", "PercentSalaryHike int64\n", "PerformanceRating int64\n", "RelationshipSatisfaction int64\n", "StandardHours int64\n", "StockOptionLevel int64\n", "TotalWorkingYears int64\n", "TrainingTimesLastYear int64\n", "WorkLifeBalance int64\n", "YearsAtCompany int64\n", "YearsInCurrentRole int64\n", "YearsSinceLastPromotion int64\n", "YearsWithCurrManager int64\n", "dtype: object" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets.dtypes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Target variable: *Attrition*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert Yes / No to 1 / 0.\n", "- 1 : attrition Yes (the employee left)\n", "- 0 : attrition No (the employee stayed)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Age Attrition BusinessTravel DailyRate Department \\\n", "0 41 Yes Travel_Rarely 1102 Sales \n", "1 49 No Travel_Frequently 279 Research & Development \n", "2 37 Yes Travel_Rarely 1373 Research & Development \n", "3 33 No Travel_Frequently 1392 Research & Development \n", "4 27 No Travel_Rarely 591 Research & Development \n", "\n", " DistanceFromHome Education EducationField EmployeeCount EmployeeNumber \\\n", "0 1 2 Life Sciences 1 1 \n", "1 8 1 Life Sciences 1 2 \n", "2 2 2 Other 1 4 \n", "3 3 4 Life Sciences 1 5 \n", "4 2 1 Medical 1 7 \n", "\n", " EnvironmentSatisfaction Gender HourlyRate JobInvolvement JobLevel \\\n", "0 2 Female 94 3 2 \n", "1 3 Male 61 2 2 \n", "2 4 Male 92 2 1 \n", "3 4 Female 56 3 1 \n", "4 1 Male 40 3 1 \n", "\n", " JobRole JobSatisfaction MaritalStatus MonthlyIncome \\\n", "0 Sales Executive 4 Single 5993 \n", "1 Research Scientist 2 Married 5130 \n", "2 Laboratory Technician 3 Single 2090 \n", "3 Research Scientist 3 Married 2909 \n", "4 Laboratory Technician 2 Married 3468 \n", "\n", " MonthlyRate NumCompaniesWorked Over18 OverTime PercentSalaryHike \\\n", "0 19479 8 Y Yes 11 \n", "1 24907 1 Y No 23 \n", "2 2396 6 Y Yes 15 \n", "3 23159 1 Y Yes 11 \n", "4 16632 9 Y No 12 \n", "\n", " PerformanceRating RelationshipSatisfaction StandardHours \\\n", "0 3 1 80 \n", "1 4 4 80 \n", "2 3 2 80 \n", "3 3 3 80 \n", "4 3 4 80 \n", "\n", " StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n", "0 0 8 0 \n", "1 1 10 3 \n", "2 0 7 3 \n", "3 0 8 3 \n", "4 1 6 3 \n", "\n", " WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "0 1 6 4 \n", "1 3 10 7 \n", "2 3 0 0 \n", "3 3 8 7 \n", "4 3 2 2 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager Attrition_idx \n", "0 0 5 1 \n", "1 1 7 0 \n", "2 0 0 1 \n", "3 3 0 0 \n", "4 2 2 0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets['Attrition_idx'] = datasets['Attrition']\\\n", " .apply(lambda x: 1 if x == 'Yes' 
else 0)\n", "datasets.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Column preprocessing" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['Age', 'Attrition', 'BusinessTravel', 'DailyRate', 'Department',\n", " 'DistanceFromHome', 'Education', 'EducationField', 'EmployeeCount',\n", " 'EmployeeNumber', 'EnvironmentSatisfaction', 'Gender', 'HourlyRate',\n", " 'JobInvolvement', 'JobLevel', 'JobRole', 'JobSatisfaction',\n", " 'MaritalStatus', 'MonthlyIncome', 'MonthlyRate', 'NumCompaniesWorked',\n", " 'Over18', 'OverTime', 'PercentSalaryHike', 'PerformanceRating',\n", " 'RelationshipSatisfaction', 'StandardHours', 'StockOptionLevel',\n", " 'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance',\n", " 'YearsAtCompany', 'YearsInCurrentRole', 'YearsSinceLastPromotion',\n", " 'YearsWithCurrManager', 'Attrition_idx'],\n", " dtype='object')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "col_names = datasets.columns\n", "col_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some columns are not useful as features: *EmployeeCount*, *Over18*, and *StandardHours* hold a single value for every row, and *EmployeeNumber* is just an identifier." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Y 1470\n", "Name: Over18, dtype: int64\n", "1 1470\n", "Name: EmployeeCount, dtype: int64\n", "80 1470\n", "Name: StandardHours, dtype: int64\n" ] } ], "source": [ "print(datasets.Over18.value_counts())\n", "print(datasets.EmployeeCount.value_counts())\n", "print(datasets.StandardHours.value_counts())" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Drop the target columns and the uninformative columns from the feature list.\n", "col_names = col_names\\\n", "    .drop(['Attrition_idx', 'Attrition', 'Over18', \n", "           'EmployeeCount', 'EmployeeNumber', 'StandardHours'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's handle the categorical columns.\n", "\n", "First, separate the categorical columns from the numerical columns." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "categorical_features = []\n", "numerical_features = []\n", "target = 'Attrition_idx'\n", "\n", "# Split the features into two types based on dtype ('O' means object, i.e. strings).\n", "for col in col_names:\n", "    if datasets[col].dtype == 'O':\n", "        categorical_features.append(col)\n", "    else:\n", "        numerical_features.append(col)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of categorical features : 7\n", "Number of numerical features : 23\n" ] } ], "source": [ "print('Number of categorical features :', len(categorical_features))\n", "print('Number of numerical features :', len(numerical_features))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['BusinessTravel',\n", " 'Department',\n", " 'EducationField',\n", " 'Gender',\n", " 'JobRole',\n", " 'MaritalStatus',\n", " 'OverTime']" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "categorical_features" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Age',\n", " 'DailyRate',\n", " 'DistanceFromHome',\n", " 'Education',\n", " 'EnvironmentSatisfaction',\n", " 'HourlyRate',\n", " 'JobInvolvement',\n", " 'JobLevel',\n", " 'JobSatisfaction',\n", " 'MonthlyIncome',\n", " 'MonthlyRate',\n", " 'NumCompaniesWorked',\n", " 'PercentSalaryHike',\n", " 'PerformanceRating',\n", " 'RelationshipSatisfaction',\n", " 'StockOptionLevel',\n", " 'TotalWorkingYears',\n", " 'TrainingTimesLastYear',\n", " 'WorkLifeBalance',\n", " 'YearsAtCompany',\n", " 'YearsInCurrentRole',\n", " 'YearsSinceLastPromotion',\n", " 'YearsWithCurrManager']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numerical_features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert the categorical data into one-hot vectors using pandas `get_dummies`.\n", "- Encoding without separating the train and test sets: appropriate when every possible value of the feature is known in advance, for example company departments or country codes.\n", "- Encoding from the train set only, after splitting: some categories may appear in only one of the two sets, so the encoded columns can differ between train and test." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n", "0 0 0 \n", "1 0 1 \n", "2 0 0 \n", "3 0 1 \n", "4 0 0 \n", "\n", " BusinessTravel_Travel_Rarely Department_Human Resources \\\n", "0 1 0 \n", "1 0 0 \n", "2 1 0 \n", "3 0 0 \n", "4 1 0 \n", "\n", " Department_Research & Development Department_Sales \\\n", "0 0 1 \n", "1 1 0 \n", "2 1 0 \n", "3 1 0 \n", "4 1 0 \n", "\n", " EducationField_Human Resources EducationField_Life Sciences \\\n", "0 0 1 \n", "1 0 1 \n", "2 0 0 \n", "3 0 1 \n", "4 0 0 \n", "\n", " EducationField_Marketing EducationField_Medical EducationField_Other \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 1 0 \n", "\n", " EducationField_Technical Degree Gender_Female Gender_Male \\\n", "0 0 1 0 \n", "1 0 0 1 \n", "2 0 0 1 \n", "3 0 1 0 \n", "4 0 0 1 \n", "\n", " JobRole_Healthcare Representative JobRole_Human Resources \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " JobRole_Laboratory Technician JobRole_Manager \\\n", "0 0 0 \n", "1 0 0 \n", "2 1 0 \n", "3 0 0 \n", "4 1 0 \n", "\n", " JobRole_Manufacturing Director JobRole_Research Director \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " JobRole_Research Scientist JobRole_Sales Executive \\\n", "0 0 1 \n", "1 1 0 \n", "2 0 0 \n", "3 1 0 \n", "4 0 0 \n", "\n", " JobRole_Sales Representative MaritalStatus_Divorced \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " MaritalStatus_Married MaritalStatus_Single OverTime_No OverTime_Yes \n", "0 0 1 0 1 \n", "1 1 0 1 0 \n", "2 0 1 0 1 \n", "3 1 0 0 1 \n", "4 1 0 1 0 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "categorical_datasets = pd.get_dummies(datasets[categorical_features])\n", "categorical_datasets.head()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Age DailyRate DistanceFromHome Education EnvironmentSatisfaction \\\n", "0 41 1102 1 2 2 \n", "1 49 279 8 1 3 \n", "2 37 1373 2 2 4 \n", "3 33 1392 3 4 4 \n", "4 27 591 2 1 1 \n", "\n", " HourlyRate JobInvolvement JobLevel JobSatisfaction MonthlyIncome \\\n", "0 94 3 2 4 5993 \n", "1 61 2 2 2 5130 \n", "2 92 2 1 3 2090 \n", "3 56 3 1 3 2909 \n", "4 40 3 1 2 3468 \n", "\n", " MonthlyRate NumCompaniesWorked PercentSalaryHike PerformanceRating \\\n", "0 19479 8 11 3 \n", "1 24907 1 23 4 \n", "2 2396 6 15 3 \n", "3 23159 1 11 3 \n", "4 16632 9 12 3 \n", "\n", " RelationshipSatisfaction StockOptionLevel TotalWorkingYears \\\n", "0 1 0 8 \n", "1 4 1 10 \n", "2 2 0 7 \n", "3 3 0 8 \n", "4 4 1 6 \n", "\n", " TrainingTimesLastYear WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "0 0 1 6 4 \n", "1 3 3 10 7 \n", "2 3 3 0 0 \n", "3 3 3 8 7 \n", "4 3 3 2 2 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager \n", "0 0 5 \n", "1 1 7 \n", "2 0 0 \n", "3 3 0 \n", "4 2 2 " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numerical_datasets = datasets[numerical_features]\n", "numerical_datasets.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Concatenate the categorical and numerical datasets. The result is the feature matrix used as the model input." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n", "0 0 0 \n", "1 0 1 \n", "2 0 0 \n", "3 0 1 \n", "4 0 0 \n", "\n", " BusinessTravel_Travel_Rarely Department_Human Resources \\\n", "0 1 0 \n", "1 0 0 \n", "2 1 0 \n", "3 0 0 \n", "4 1 0 \n", "\n", " Department_Research & Development Department_Sales \\\n", "0 0 1 \n", "1 1 0 \n", "2 1 0 \n", "3 1 0 \n", "4 1 0 \n", "\n", " EducationField_Human Resources EducationField_Life Sciences \\\n", "0 0 1 \n", "1 0 1 \n", "2 0 0 \n", "3 0 1 \n", "4 0 0 \n", "\n", " EducationField_Marketing EducationField_Medical EducationField_Other \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 1 0 \n", "\n", " EducationField_Technical Degree Gender_Female Gender_Male \\\n", "0 0 1 0 \n", "1 0 0 1 \n", "2 0 0 1 \n", "3 0 1 0 \n", "4 0 0 1 \n", "\n", " JobRole_Healthcare Representative JobRole_Human Resources \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " JobRole_Laboratory Technician JobRole_Manager \\\n", "0 0 0 \n", "1 0 0 \n", "2 1 0 \n", "3 0 0 \n", "4 1 0 \n", "\n", " JobRole_Manufacturing Director JobRole_Research Director \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " JobRole_Research Scientist JobRole_Sales Executive \\\n", "0 0 1 \n", "1 1 0 \n", "2 0 0 \n", "3 1 0 \n", "4 0 0 \n", "\n", " JobRole_Sales Representative MaritalStatus_Divorced \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " MaritalStatus_Married MaritalStatus_Single OverTime_No OverTime_Yes \\\n", "0 0 1 0 1 \n", "1 1 0 1 0 \n", "2 0 1 0 1 \n", "3 1 0 0 1 \n", "4 1 0 1 0 \n", "\n", " Age DailyRate DistanceFromHome Education EnvironmentSatisfaction \\\n", "0 41 1102 1 2 2 \n", "1 49 279 8 1 3 \n", "2 37 1373 2 2 4 \n", "3 33 1392 3 4 4 \n", "4 27 591 2 1 1 \n", "\n", " HourlyRate JobInvolvement JobLevel JobSatisfaction MonthlyIncome \\\n", "0 94 3 2 4 5993 \n", "1 61 2 2 2 5130 \n", "2 92 2 1 3 2090 \n", "3 56 3 1 3 2909 \n", "4 
40 3 1 2 3468 \n", "\n", " MonthlyRate NumCompaniesWorked PercentSalaryHike PerformanceRating \\\n", "0 19479 8 11 3 \n", "1 24907 1 23 4 \n", "2 2396 6 15 3 \n", "3 23159 1 11 3 \n", "4 16632 9 12 3 \n", "\n", " RelationshipSatisfaction StockOptionLevel TotalWorkingYears \\\n", "0 1 0 8 \n", "1 4 1 10 \n", "2 2 0 7 \n", "3 3 0 8 \n", "4 4 1 6 \n", "\n", " TrainingTimesLastYear WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "0 0 1 6 4 \n", "1 3 3 10 7 \n", "2 3 3 0 0 \n", "3 3 3 8 7 \n", "4 3 3 2 2 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager \n", "0 0 5 \n", "1 1 7 \n", "2 0 0 \n", "3 3 0 \n", "4 2 2 " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = pd.concat([categorical_datasets, numerical_datasets], axis=1)\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 1\n", "1 0\n", "2 1\n", "3 0\n", "4 0\n", "Name: Attrition_idx, dtype: int64" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = datasets[target]\n", "y.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Splitting into train and test sets" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "x_train, x_test, y_train, y_test = \\\n", "    train_test_split(X, y, test_size=0.2, random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Model training and hyperparameter search" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build the classifier with a decision tree. The following parameters are tuned with grid search:\n", "- Maximum depth of the tree (`max_depth`)\n", "- Minimum number of samples required to split a node (`min_samples_split`, fixed at 2 in the grid below)\n", "- Minimum number of samples required at each leaf node (`min_samples_leaf`)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.model_selection import GridSearchCV" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 3 folds for each of 12 candidates, totalling 36 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Done 36 out of 36 | elapsed: 0.2s finished\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=3, error_score='raise',\n", " estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False, random_state=42,\n", " splitter='best'),\n", " fit_params=None, iid=True, n_jobs=-1,\n", " param_grid={'max_depth': [5, 7, 9], 'min_samples_split': [2], 'min_samples_leaf': [1, 2, 3, 4]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n", " scoring=None, verbose=1)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = {\n", "    'max_depth': [5, 7, 9],\n", "    'min_samples_split': [2], \n", "    'min_samples_leaf': [1, 2, 3, 4]\n", "}\n", "\n", "grid_search_cv = \\\n", "    GridSearchCV(\n", "        DecisionTreeClassifier(random_state=42), \n", "        params, \n", "        n_jobs=-1, \n", "        verbose=1, \n", "        cv=3)\n", "\n", "grid_search_cv.fit(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " 
min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False, random_state=42,\n", " splitter='best')" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Retrieve the tree that was refit with the best parameter combination found by the search.\n", "tree_classifier = grid_search_cv.best_estimator_\n", "tree_classifier" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8273809523809523" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check the best score (mean cross-validated accuracy of the best candidate).\n", "grid_search_cv.best_score_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Measuring Model Performance" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "pred_train = tree_classifier.predict(x_train)\n", "pred_test = tree_classifier.predict(x_test)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score, classification_report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the results on the training set." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Train Confusion Matrix :\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 951 27\n", "1 98 100" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " Train accuracy : 0.8937074829931972\n", "\n", " Classification Report : \n", " precision recall f1-score support\n", "\n", " 0 0.91 0.97 0.94 978\n", " 1 0.79 0.51 0.62 198\n", "\n", "avg / total 0.89 0.89 0.88 1176\n", "\n" ] } ], "source": [ "# 1. Confusion Matrix\n", "print('\\n Train Confusion Matrix :')\n", "display(pd.crosstab(y_train, pred_train, rownames=['Actual'], colnames=['Predict']))\n", "\n", "# 2. Accuracy\n", "print('\\n Train accuracy :', accuracy_score(y_train, pred_train))\n", "\n", "# 3. Classification Report\n", "print('\\n Classification Report : \\n', classification_report(y_train, pred_train))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at the results on the test set." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Test Confusion Matrix :\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 235 20\n", "1 33 6" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " Test accuracy : 0.8197278911564626\n", "\n", " Classification Report : \n", " precision recall f1-score support\n", "\n", " 0 0.88 0.92 0.90 255\n", " 1 0.23 0.15 0.18 39\n", "\n", "avg / total 0.79 0.82 0.80 294\n", "\n" ] } ], "source": [ "# 1. Confusion Matrix\n", "print('\\n Test Confusion Matrix :')\n", "display(pd.crosstab(y_test, pred_test, rownames=['Actual'], colnames=['Predict']))\n", "\n", "# 2. Accuracy\n", "print('\\n Test accuracy :', accuracy_score(y_test, pred_test))\n", "\n", "# 3. Classification Report\n", "print('\\n Classification Report : \\n', classification_report(y_test, pred_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Adjusting Class Weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A test accuracy of about 82% looks high, but it means little on its own." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 1233\n", "1 237\n", "Name: Attrition_idx, dtype: int64" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets.Attrition_idx.value_counts()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8388" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Share of class 0 in the whole dataset, i.e. the accuracy of always predicting 0\n", "round(1233 / (1233 + 237), 4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The classes are imbalanced: class 0 outnumbers class 1 by roughly 5 to 1 (1233 vs. 237). A classifier that labels every sample as 0 would therefore already reach about 83.9% accuracy on the full dataset, and 255/294 ≈ 86.7% on the test set, which is higher than the 82.0% obtained above. The model is failing to classify class 1 (leavers) properly.\n", "\n", "If this model were used to prevent attrition, for example by offering larger bonuses to employees who are likely to leave, this would be a serious problem. Of the 39 employees in the test set who actually left, the model predicted that 33 would stay.\n", "\n", "Let's tune the model slightly by adjusting the class weights. For example, raising the weight of class 1 (leavers) lets the model identify more of the employees who are actually at risk of leaving, at the cost of flagging some employees who would have stayed as potential leavers. In other words, more attrition can be prevented in exchange for some false alarms. (When granting loans, it is better to reject a few applicants whose credit is acceptable than to approve applicants with poor credit. That is an error we can afford.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We try a range of class weights and compare the results." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# One row per class-weight setting tried below (11) and one column per metric (10)\n", "tuning_results = pd.DataFrame(np.empty((11, 10)))" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "tuning_results.columns = ['class_0_weight', 'class_1_weight', \n", " 'train_accuracy', 'test_accuracy', \n", " 'precision_class_0', 'precision_class_1', 'precision_overall', \n", " 'recall_calss_0', 'recall_class_1', 'recall_overall']" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['precision', 'recall', 'f1-score', 'support', '0', '0.88', '0.92', '0.90', '255', '1', '0.23', '0.15', '0.18', '39', 'avg', '/', 'total', '0.79', '0.82', '0.80', '294']\n" ] } ], "source": [ "# Token list of the report; metrics are read from it by position when building the results table.\n", "print(classification_report(y_test, pred_test).split())" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.01, 1: 0.99}\n", "Test accuracy : 0.2925170068027211\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 50 205\n", "1 3 36" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.1, 1: 0.9}\n", "Test accuracy : 0.6972789115646258\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 183 72\n", "1 17 22" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.2, 1: 0.8}\n", "Test accuracy : 0.7891156462585034\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 216 39\n", "1 23 16" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.3, 1: 0.7}\n", "Test accuracy : 0.8095238095238095\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 226 29\n", "1 27 12" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.4, 1: 0.6}\n", "Test accuracy : 0.7993197278911565\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 225 30\n", "1 29 10" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.5, 1: 0.5}\n", "Test accuracy : 0.8197278911564626\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 235 20\n", "1 33 6" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.6, 1: 0.4}\n", "Test accuracy : 0.8469387755102041\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 247 8\n", "1 37 2" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.7, 1: 0.30000000000000004}\n", "Test accuracy : 0.8537414965986394\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 248 7\n", "1 36 3" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.8, 1: 0.19999999999999996}\n", "Test accuracy : 0.8571428571428571\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 250 5\n", "1 37 2" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.9, 1: 0.09999999999999998}\n", "Test accuracy : 0.8673469387755102\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 253 2\n", "1 37 2" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0.99, 1: 0.010000000000000009}\n", "Test accuracy : 0.8707482993197279\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "Predict 0 1\n", "Actual \n", "0 255 0\n", "1 38 1" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Weights tried for class 0; class 1 receives the complement so each pair sums to 1\n", "class_0_weight = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]\n", "\n", "for i in range(len(class_0_weight)):\n", " class_weights = {0: class_0_weight[i], 1: 1 - class_0_weight[i]}\n", " # Same tree hyperparameters as the grid-search winner, with only class_weight varied\n", " tree_classifier = DecisionTreeClassifier(criterion='gini',\n", " max_depth=5,\n", " min_samples_split=2,\n", " min_samples_leaf=1,\n", " random_state=42,\n", " class_weight=class_weights)\n", " tree_classifier.fit(x_train, y_train)\n", " pred_train = tree_classifier.predict(x_train)\n", " pred_test = tree_classifier.predict(x_test)\n", " tuning_results.loc[i, 'class_0_weight'] = class_weights[0]\n", " tuning_results.loc[i, 'class_1_weight'] = class_weights[1]\n", " tuning_results.loc[i, 'train_accuracy'] = round(accuracy_score(y_train, pred_train), 4)\n", " tuning_results.loc[i, 'test_accuracy'] = round(accuracy_score(y_test, pred_test), 4)\n", " # Token positions follow the split() output printed earlier (the 'avg / total' report layout)\n", " c_r = classification_report(y_test, pred_test).split()\n", " tuning_results.loc[i, 'precision_class_0'] = float(c_r[5])\n", " tuning_results.loc[i, 'precision_class_1'] = float(c_r[10])\n", " tuning_results.loc[i, 'precision_overall'] = float(c_r[17])\n", " tuning_results.loc[i, 'recall_calss_0'] = float(c_r[6])\n", " tuning_results.loc[i, 'recall_class_1'] = float(c_r[11])\n", " tuning_results.loc[i, 'recall_overall'] = float(c_r[18])\n", "\n", " print(class_weights)\n", " print('Test accuracy :', accuracy_score(y_test, pred_test))\n", " display(pd.crosstab(y_test, pred_test, rownames=['Actual'], colnames=['Predict']))" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " class_0_weight class_1_weight train_accuracy test_accuracy \\\n", "0 0.01 0.99 0.3580 0.2925 \n", "1 0.10 0.90 0.7976 0.6973 \n", "2 0.20 0.80 0.8759 0.7891 \n", "3 0.30 0.70 0.8912 0.8095 \n", "4 0.40 0.60 0.8903 0.7993 \n", "5 0.50 0.50 0.8937 0.8197 \n", "6 0.60 0.40 0.8954 0.8469 \n", "7 0.70 0.30 0.8963 0.8537 \n", "8 0.80 0.20 0.8869 0.8571 \n", "9 0.90 0.10 0.8622 0.8673 \n", "10 0.99 0.01 0.8435 0.8707 \n", "\n", " precision_class_0 precision_class_1 precision_overall recall_calss_0 \\\n", "0 0.94 0.15 0.84 0.20 \n", "1 0.92 0.23 0.82 0.72 \n", "2 0.90 0.29 0.82 0.85 \n", "3 0.89 0.29 0.81 0.89 \n", "4 0.89 0.25 0.80 0.88 \n", "5 0.88 0.23 0.79 0.92 \n", "6 0.87 0.20 0.78 0.97 \n", "7 0.87 0.30 0.80 0.97 \n", "8 0.87 0.29 0.79 0.98 \n", "9 0.87 0.50 0.82 0.99 \n", "10 0.87 1.00 0.89 1.00 \n", "\n", " recall_class_1 recall_overall \n", "0 0.92 0.29 \n", "1 0.56 0.70 \n", "2 0.41 0.79 \n", "3 0.31 0.81 \n", "4 0.26 0.80 \n", "5 0.15 0.82 \n", "6 0.05 0.85 \n", "7 0.08 0.85 \n", "8 0.05 0.86 \n", "9 0.05 0.87 \n", "10 0.03 0.87 " ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuning_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the weight of class 0 increases, the model predicts class 0 more often. Because more samples are assigned to class 0, its recall rises (from 0.20 to 1.00 in the table above), and test accuracy rises with it because class 0 is the majority. At the same time, more actual leavers are swept into the class 0 predictions, so the precision for class 0 falls (from 0.94 to 0.87) and the recall for class 1 collapses (from 0.92 to 0.03)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the results, a class 0 weight of 0.3 gives a reasonable balance: test accuracy is 0.81, and class 1 is detected with precision 0.29 and recall 0.31." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }