{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Building an Attrition Classifier\n",
"\n",
"- Uses the HR data published by IBM\n",
"- Builds a decision-tree classifier that predicts whether an employee will leave the company"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"pd.options.display.max_columns=None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Loading the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"IBM HR dataset on Kaggle: https://www.kaggle.com/pavansubhasht/ibm-hr-analytics-attrition-dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Age Attrition BusinessTravel DailyRate Department \\\n",
"0 41 Yes Travel_Rarely 1102 Sales \n",
"1 49 No Travel_Frequently 279 Research & Development \n",
"2 37 Yes Travel_Rarely 1373 Research & Development \n",
"3 33 No Travel_Frequently 1392 Research & Development \n",
"4 27 No Travel_Rarely 591 Research & Development \n",
"\n",
" DistanceFromHome Education EducationField EmployeeCount EmployeeNumber \\\n",
"0 1 2 Life Sciences 1 1 \n",
"1 8 1 Life Sciences 1 2 \n",
"2 2 2 Other 1 4 \n",
"3 3 4 Life Sciences 1 5 \n",
"4 2 1 Medical 1 7 \n",
"\n",
" EnvironmentSatisfaction Gender HourlyRate JobInvolvement JobLevel \\\n",
"0 2 Female 94 3 2 \n",
"1 3 Male 61 2 2 \n",
"2 4 Male 92 2 1 \n",
"3 4 Female 56 3 1 \n",
"4 1 Male 40 3 1 \n",
"\n",
" JobRole JobSatisfaction MaritalStatus MonthlyIncome \\\n",
"0 Sales Executive 4 Single 5993 \n",
"1 Research Scientist 2 Married 5130 \n",
"2 Laboratory Technician 3 Single 2090 \n",
"3 Research Scientist 3 Married 2909 \n",
"4 Laboratory Technician 2 Married 3468 \n",
"\n",
" MonthlyRate NumCompaniesWorked Over18 OverTime PercentSalaryHike \\\n",
"0 19479 8 Y Yes 11 \n",
"1 24907 1 Y No 23 \n",
"2 2396 6 Y Yes 15 \n",
"3 23159 1 Y Yes 11 \n",
"4 16632 9 Y No 12 \n",
"\n",
" PerformanceRating RelationshipSatisfaction StandardHours \\\n",
"0 3 1 80 \n",
"1 4 4 80 \n",
"2 3 2 80 \n",
"3 3 3 80 \n",
"4 3 4 80 \n",
"\n",
" StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n",
"0 0 8 0 \n",
"1 1 10 3 \n",
"2 0 7 3 \n",
"3 0 8 3 \n",
"4 1 6 3 \n",
"\n",
" WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"0 1 6 4 \n",
"1 3 10 7 \n",
"2 3 0 0 \n",
"3 3 8 7 \n",
"4 3 2 2 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager \n",
"0 0 5 \n",
"1 1 7 \n",
"2 0 0 \n",
"3 3 0 \n",
"4 2 2 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets = pd.read_csv('./inputs/HR-Employee-Attrition.csv')\n",
"datasets.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1470, 35)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data contains both categorical and numerical features."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Age int64\n",
"Attrition object\n",
"BusinessTravel object\n",
"DailyRate int64\n",
"Department object\n",
"DistanceFromHome int64\n",
"Education int64\n",
"EducationField object\n",
"EmployeeCount int64\n",
"EmployeeNumber int64\n",
"EnvironmentSatisfaction int64\n",
"Gender object\n",
"HourlyRate int64\n",
"JobInvolvement int64\n",
"JobLevel int64\n",
"JobRole object\n",
"JobSatisfaction int64\n",
"MaritalStatus object\n",
"MonthlyIncome int64\n",
"MonthlyRate int64\n",
"NumCompaniesWorked int64\n",
"Over18 object\n",
"OverTime object\n",
"PercentSalaryHike int64\n",
"PerformanceRating int64\n",
"RelationshipSatisfaction int64\n",
"StandardHours int64\n",
"StockOptionLevel int64\n",
"TotalWorkingYears int64\n",
"TrainingTimesLastYear int64\n",
"WorkLifeBalance int64\n",
"YearsAtCompany int64\n",
"YearsInCurrentRole int64\n",
"YearsSinceLastPromotion int64\n",
"YearsWithCurrManager int64\n",
"dtype: object"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets.dtypes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Target variable: *Attrition*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convert Yes / No to 1 / 0:\n",
"- 1: attrition Yes\n",
"- 0: attrition No"
]
},
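{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of this conversion, the cell below applies `Series.map` with an explicit dictionary to a small made-up Series. The name `toy_attrition` is hypothetical and not part of the dataset. Unlike an `else 0` branch, `map` leaves any unexpected label as NaN, which can make data problems easier to spot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy labels standing in for the Attrition column.\n",
"toy_attrition = pd.Series(['Yes', 'No', 'No', 'Yes'])\n",
"\n",
"# Explicit mapping: a label missing from the dict would become NaN.\n",
"toy_idx = toy_attrition.map({'Yes': 1, 'No': 0})\n",
"print(toy_idx.tolist())"
]
},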
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Age Attrition BusinessTravel DailyRate Department \\\n",
"0 41 Yes Travel_Rarely 1102 Sales \n",
"1 49 No Travel_Frequently 279 Research & Development \n",
"2 37 Yes Travel_Rarely 1373 Research & Development \n",
"3 33 No Travel_Frequently 1392 Research & Development \n",
"4 27 No Travel_Rarely 591 Research & Development \n",
"\n",
" DistanceFromHome Education EducationField EmployeeCount EmployeeNumber \\\n",
"0 1 2 Life Sciences 1 1 \n",
"1 8 1 Life Sciences 1 2 \n",
"2 2 2 Other 1 4 \n",
"3 3 4 Life Sciences 1 5 \n",
"4 2 1 Medical 1 7 \n",
"\n",
" EnvironmentSatisfaction Gender HourlyRate JobInvolvement JobLevel \\\n",
"0 2 Female 94 3 2 \n",
"1 3 Male 61 2 2 \n",
"2 4 Male 92 2 1 \n",
"3 4 Female 56 3 1 \n",
"4 1 Male 40 3 1 \n",
"\n",
" JobRole JobSatisfaction MaritalStatus MonthlyIncome \\\n",
"0 Sales Executive 4 Single 5993 \n",
"1 Research Scientist 2 Married 5130 \n",
"2 Laboratory Technician 3 Single 2090 \n",
"3 Research Scientist 3 Married 2909 \n",
"4 Laboratory Technician 2 Married 3468 \n",
"\n",
" MonthlyRate NumCompaniesWorked Over18 OverTime PercentSalaryHike \\\n",
"0 19479 8 Y Yes 11 \n",
"1 24907 1 Y No 23 \n",
"2 2396 6 Y Yes 15 \n",
"3 23159 1 Y Yes 11 \n",
"4 16632 9 Y No 12 \n",
"\n",
" PerformanceRating RelationshipSatisfaction StandardHours \\\n",
"0 3 1 80 \n",
"1 4 4 80 \n",
"2 3 2 80 \n",
"3 3 3 80 \n",
"4 3 4 80 \n",
"\n",
" StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n",
"0 0 8 0 \n",
"1 1 10 3 \n",
"2 0 7 3 \n",
"3 0 8 3 \n",
"4 1 6 3 \n",
"\n",
" WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"0 1 6 4 \n",
"1 3 10 7 \n",
"2 3 0 0 \n",
"3 3 8 7 \n",
"4 3 2 2 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager Attrition_idx \n",
"0 0 5 1 \n",
"1 1 7 0 \n",
"2 0 0 1 \n",
"3 3 0 0 \n",
"4 2 2 0 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets['Attrition_idx'] = datasets['Attrition']\\\n",
" .apply(lambda x: 1 if x == 'Yes' else 0)\n",
"datasets.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Column preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['Age', 'Attrition', 'BusinessTravel', 'DailyRate', 'Department',\n",
" 'DistanceFromHome', 'Education', 'EducationField', 'EmployeeCount',\n",
" 'EmployeeNumber', 'EnvironmentSatisfaction', 'Gender', 'HourlyRate',\n",
" 'JobInvolvement', 'JobLevel', 'JobRole', 'JobSatisfaction',\n",
" 'MaritalStatus', 'MonthlyIncome', 'MonthlyRate', 'NumCompaniesWorked',\n",
" 'Over18', 'OverTime', 'PercentSalaryHike', 'PerformanceRating',\n",
" 'RelationshipSatisfaction', 'StandardHours', 'StockOptionLevel',\n",
" 'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance',\n",
" 'YearsAtCompany', 'YearsInCurrentRole', 'YearsSinceLastPromotion',\n",
" 'YearsWithCurrManager', 'Attrition_idx'],\n",
" dtype='object')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"col_names = datasets.columns\n",
"col_names"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some variables are not needed: *EmployeeCount*, *EmployeeNumber*, *Over18*, *StandardHours*"
]
},
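{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to spot such columns is to look for features with a single distinct value. The cell below is a sketch on a made-up frame; `toy` and `constant_cols` are hypothetical names used only for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy frame: 'Over18' and 'StandardHours' never vary.\n",
"toy = pd.DataFrame({\n",
"    'Age': [41, 49, 37],\n",
"    'Over18': ['Y', 'Y', 'Y'],\n",
"    'StandardHours': [80, 80, 80],\n",
"})\n",
"\n",
"# A column with one distinct value cannot help separate the classes.\n",
"constant_cols = [c for c in toy.columns if toy[c].nunique() == 1]\n",
"print(constant_cols)"
]
},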
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Y 1470\n",
"Name: Over18, dtype: int64\n",
"1 1470\n",
"Name: EmployeeCount, dtype: int64\n",
"80 1470\n",
"Name: StandardHours, dtype: int64\n"
]
}
],
"source": [
"print(datasets.Over18.value_counts())\n",
"print(datasets.EmployeeCount.value_counts())\n",
"print(datasets.StandardHours.value_counts())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Exclude the target and the unneeded columns from the features.\n",
"col_names = col_names\\\n",
" .drop(['Attrition_idx', 'Attrition', 'Over18', \n",
" 'EmployeeCount', 'EmployeeNumber', 'StandardHours'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's handle the categorical columns.\n",
"\n",
"First, split the columns into categorical and numerical ones."
]
},
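{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same split can also be sketched with `DataFrame.select_dtypes`, shown here on a made-up frame. The names `toy`, `toy_numerical`, and `toy_categorical` are hypothetical, and treating every non-numeric column as categorical is an assumption that happens to hold for this kind of data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy frame with two text columns and one integer column.\n",
"toy = pd.DataFrame({\n",
"    'Gender': ['Female', 'Male'],\n",
"    'Age': [41, 49],\n",
"    'OverTime': ['Yes', 'No'],\n",
"})\n",
"\n",
"# Numeric columns are selected by dtype; everything else is treated as categorical.\n",
"toy_numerical = toy.select_dtypes(include='number').columns.tolist()\n",
"toy_categorical = toy.select_dtypes(exclude='number').columns.tolist()\n",
"print(toy_categorical, toy_numerical)"
]
},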
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"categorical_features = []\n",
"numerical_features = []\n",
"target = 'Attrition_idx'\n",
"\n",
"# Split the features into two types.\n",
"for col in col_names:\n",
"    if datasets[col].dtype == 'O':\n",
"        categorical_features.append(col)\n",
"    else:\n",
"        numerical_features.append(col)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of categorical features : 7\n",
"Number of numerical features : 23\n"
]
}
],
"source": [
"print('Number of categorical features :', len(categorical_features))\n",
"print('Number of numerical features :', len(numerical_features))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['BusinessTravel',\n",
" 'Department',\n",
" 'EducationField',\n",
" 'Gender',\n",
" 'JobRole',\n",
" 'MaritalStatus',\n",
" 'OverTime']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"categorical_features"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Age',\n",
" 'DailyRate',\n",
" 'DistanceFromHome',\n",
" 'Education',\n",
" 'EnvironmentSatisfaction',\n",
" 'HourlyRate',\n",
" 'JobInvolvement',\n",
" 'JobLevel',\n",
" 'JobSatisfaction',\n",
" 'MonthlyIncome',\n",
" 'MonthlyRate',\n",
" 'NumCompaniesWorked',\n",
" 'PercentSalaryHike',\n",
" 'PerformanceRating',\n",
" 'RelationshipSatisfaction',\n",
" 'StockOptionLevel',\n",
" 'TotalWorkingYears',\n",
" 'TrainingTimesLastYear',\n",
" 'WorkLifeBalance',\n",
" 'YearsAtCompany',\n",
" 'YearsInCurrentRole',\n",
" 'YearsSinceLastPromotion',\n",
" 'YearsWithCurrManager']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numerical_features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
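"Convert the categorical data to one-hot vectors using pandas `get_dummies`.\n",
"- Building one-hot vectors without separating the train and test sets: appropriate when every value of the feature is known in advance, for example company departments or country codes.\n",
"- Building one-hot vectors from the train set only, after separating train and test: some values of the feature may be absent from the test set (or appear only there), so the test columns have to be aligned with the train columns."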
]
},
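{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second case can be sketched as follows. This notebook itself encodes the full dataset at once, so the cell is only an illustration on made-up frames: `toy_train`, `toy_test`, `train_ohe`, and `test_ohe` are hypothetical names. The test columns are aligned to the train columns with `reindex`, so a category seen only in the test set is dropped and a category missing from it is filled with 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical splits: 'Research & Development' appears only in the test set.\n",
"toy_train = pd.DataFrame({'Department': ['Sales', 'Human Resources', 'Sales']})\n",
"toy_test = pd.DataFrame({'Department': ['Sales', 'Research & Development']})\n",
"\n",
"# The one-hot columns are defined by the train set only.\n",
"train_ohe = pd.get_dummies(toy_train)\n",
"\n",
"# Align the test columns to the train columns.\n",
"test_ohe = pd.get_dummies(toy_test).reindex(columns=train_ohe.columns, fill_value=0)\n",
"print(test_ohe.columns.tolist())"
]
},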
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n",
"0 0 0 \n",
"1 0 1 \n",
"2 0 0 \n",
"3 0 1 \n",
"4 0 0 \n",
"\n",
" BusinessTravel_Travel_Rarely Department_Human Resources \\\n",
"0 1 0 \n",
"1 0 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 1 0 \n",
"\n",
" Department_Research & Development Department_Sales \\\n",
"0 0 1 \n",
"1 1 0 \n",
"2 1 0 \n",
"3 1 0 \n",
"4 1 0 \n",
"\n",
" EducationField_Human Resources EducationField_Life Sciences \\\n",
"0 0 1 \n",
"1 0 1 \n",
"2 0 0 \n",
"3 0 1 \n",
"4 0 0 \n",
"\n",
" EducationField_Marketing EducationField_Medical EducationField_Other \\\n",
"0 0 0 0 \n",
"1 0 0 0 \n",
"2 0 0 1 \n",
"3 0 0 0 \n",
"4 0 1 0 \n",
"\n",
" EducationField_Technical Degree Gender_Female Gender_Male \\\n",
"0 0 1 0 \n",
"1 0 0 1 \n",
"2 0 0 1 \n",
"3 0 1 0 \n",
"4 0 0 1 \n",
"\n",
" JobRole_Healthcare Representative JobRole_Human Resources \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"\n",
" JobRole_Laboratory Technician JobRole_Manager \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 1 0 \n",
"\n",
" JobRole_Manufacturing Director JobRole_Research Director \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"\n",
" JobRole_Research Scientist JobRole_Sales Executive \\\n",
"0 0 1 \n",
"1 1 0 \n",
"2 0 0 \n",
"3 1 0 \n",
"4 0 0 \n",
"\n",
" JobRole_Sales Representative MaritalStatus_Divorced \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"\n",
" MaritalStatus_Married MaritalStatus_Single OverTime_No OverTime_Yes \n",
"0 0 1 0 1 \n",
"1 1 0 1 0 \n",
"2 0 1 0 1 \n",
"3 1 0 0 1 \n",
"4 1 0 1 0 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"categorical_datasets = pd.get_dummies(datasets[categorical_features])\n",
"categorical_datasets.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Age DailyRate DistanceFromHome Education EnvironmentSatisfaction \\\n",
"0 41 1102 1 2 2 \n",
"1 49 279 8 1 3 \n",
"2 37 1373 2 2 4 \n",
"3 33 1392 3 4 4 \n",
"4 27 591 2 1 1 \n",
"\n",
" HourlyRate JobInvolvement JobLevel JobSatisfaction MonthlyIncome \\\n",
"0 94 3 2 4 5993 \n",
"1 61 2 2 2 5130 \n",
"2 92 2 1 3 2090 \n",
"3 56 3 1 3 2909 \n",
"4 40 3 1 2 3468 \n",
"\n",
" MonthlyRate NumCompaniesWorked PercentSalaryHike PerformanceRating \\\n",
"0 19479 8 11 3 \n",
"1 24907 1 23 4 \n",
"2 2396 6 15 3 \n",
"3 23159 1 11 3 \n",
"4 16632 9 12 3 \n",
"\n",
" RelationshipSatisfaction StockOptionLevel TotalWorkingYears \\\n",
"0 1 0 8 \n",
"1 4 1 10 \n",
"2 2 0 7 \n",
"3 3 0 8 \n",
"4 4 1 6 \n",
"\n",
" TrainingTimesLastYear WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"0 0 1 6 4 \n",
"1 3 3 10 7 \n",
"2 3 3 0 0 \n",
"3 3 3 8 7 \n",
"4 3 3 2 2 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager \n",
"0 0 5 \n",
"1 1 7 \n",
"2 0 0 \n",
"3 3 0 \n",
"4 2 2 "
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numerical_datasets = datasets[numerical_features]\n",
"numerical_datasets.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concatenate the categorical and numerical datasets. The result is the feature matrix used as the model input."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n",
"0 0 0 \n",
"1 0 1 \n",
"2 0 0 \n",
"3 0 1 \n",
"4 0 0 \n",
"\n",
" BusinessTravel_Travel_Rarely Department_Human Resources \\\n",
"0 1 0 \n",
"1 0 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 1 0 \n",
"\n",
" Department_Research & Development Department_Sales \\\n",
"0 0 1 \n",
"1 1 0 \n",
"2 1 0 \n",
"3 1 0 \n",
"4 1 0 \n",
"\n",
" EducationField_Human Resources EducationField_Life Sciences \\\n",
"0 0 1 \n",
"1 0 1 \n",
"2 0 0 \n",
"3 0 1 \n",
"4 0 0 \n",
"\n",
" EducationField_Marketing EducationField_Medical EducationField_Other \\\n",
"0 0 0 0 \n",
"1 0 0 0 \n",
"2 0 0 1 \n",
"3 0 0 0 \n",
"4 0 1 0 \n",
"\n",
" EducationField_Technical Degree Gender_Female Gender_Male \\\n",
"0 0 1 0 \n",
"1 0 0 1 \n",
"2 0 0 1 \n",
"3 0 1 0 \n",
"4 0 0 1 \n",
"\n",
" JobRole_Healthcare Representative JobRole_Human Resources \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"\n",
" JobRole_Laboratory Technician JobRole_Manager \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 1 0 \n",
"\n",
" JobRole_Manufacturing Director JobRole_Research Director \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"\n",
" JobRole_Research Scientist JobRole_Sales Executive \\\n",
"0 0 1 \n",
"1 1 0 \n",
"2 0 0 \n",
"3 1 0 \n",
"4 0 0 \n",
"\n",
" JobRole_Sales Representative MaritalStatus_Divorced \\\n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 \n",
"\n",
" MaritalStatus_Married MaritalStatus_Single OverTime_No OverTime_Yes \\\n",
"0 0 1 0 1 \n",
"1 1 0 1 0 \n",
"2 0 1 0 1 \n",
"3 1 0 0 1 \n",
"4 1 0 1 0 \n",
"\n",
" Age DailyRate DistanceFromHome Education EnvironmentSatisfaction \\\n",
"0 41 1102 1 2 2 \n",
"1 49 279 8 1 3 \n",
"2 37 1373 2 2 4 \n",
"3 33 1392 3 4 4 \n",
"4 27 591 2 1 1 \n",
"\n",
" HourlyRate JobInvolvement JobLevel JobSatisfaction MonthlyIncome \\\n",
"0 94 3 2 4 5993 \n",
"1 61 2 2 2 5130 \n",
"2 92 2 1 3 2090 \n",
"3 56 3 1 3 2909 \n",
"4 40 3 1 2 3468 \n",
"\n",
" MonthlyRate NumCompaniesWorked PercentSalaryHike PerformanceRating \\\n",
"0 19479 8 11 3 \n",
"1 24907 1 23 4 \n",
"2 2396 6 15 3 \n",
"3 23159 1 11 3 \n",
"4 16632 9 12 3 \n",
"\n",
" RelationshipSatisfaction StockOptionLevel TotalWorkingYears \\\n",
"0 1 0 8 \n",
"1 4 1 10 \n",
"2 2 0 7 \n",
"3 3 0 8 \n",
"4 4 1 6 \n",
"\n",
" TrainingTimesLastYear WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"0 0 1 6 4 \n",
"1 3 3 10 7 \n",
"2 3 3 0 0 \n",
"3 3 3 8 7 \n",
"4 3 3 2 2 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager \n",
"0 0 5 \n",
"1 1 7 \n",
"2 0 0 \n",
"3 3 0 \n",
"4 2 2 "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = pd.concat([categorical_datasets, numerical_datasets], axis=1)\n",
"X.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 1\n",
"1 0\n",
"2 1\n",
"3 0\n",
"4 0\n",
"Name: Attrition_idx, dtype: int64"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = datasets[target]\n",
"y.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Split into train and test sets"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"x_train, x_test, y_train, y_test = \\\n",
" train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Model training and hyperparameter search"
]
},
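{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build the classifier with a decision tree. The following parameters are tuned with a grid search.\n",
" - Maximum depth of the tree (`max_depth`)\n",
" - Minimum number of samples required to split a node (`min_samples_split`)\n",
" - Minimum number of samples required at each leaf node (`min_samples_leaf`)"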
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.model_selection import GridSearchCV"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 12 candidates, totalling 36 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Done 36 out of 36 | elapsed: 0.2s finished\n"
]
},
{
"data": {
"text/plain": [
"GridSearchCV(cv=3, error_score='raise',\n",
" estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
" max_features=None, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, presort=False, random_state=42,\n",
" splitter='best'),\n",
" fit_params=None, iid=True, n_jobs=-1,\n",
" param_grid={'max_depth': [5, 7, 9], 'min_samples_split': [2], 'min_samples_leaf': [1, 2, 3, 4]},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
" scoring=None, verbose=1)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"params = {\n",
" 'max_depth': [5,7,9],\n",
" 'min_samples_split': [2], \n",
" 'min_samples_leaf': [1, 2, 3, 4]\n",
"}\n",
"\n",
"grid_search_cv = \\\n",
" GridSearchCV(\n",
" DecisionTreeClassifier(random_state=42), \n",
" params, \n",
" n_jobs=-1, \n",
" verbose=1, \n",
" cv=3)\n",
"\n",
"grid_search_cv.fit(x_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,\n",
" max_features=None, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, presort=False, random_state=42,\n",
" splitter='best')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Retrieve the tree model refit with the best-scoring parameters from the search.\n",
"tree_classifier = grid_search_cv.best_estimator_\n",
"tree_classifier"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8273809523809523"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the mean cross-validated score of the best parameter combination.\n",
"grid_search_cv.best_score_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Measuring model performance"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"pred_train = tree_classifier.predict(x_train)\n",
"pred_test = tree_classifier.predict(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score, classification_report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the results on the training set."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Train Confusion Matrix :\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 951 27\n",
"1 98 100"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Train accuracy : 0.8937074829931972\n",
"\n",
" Classification Report : \n",
" precision recall f1-score support\n",
"\n",
" 0 0.91 0.97 0.94 978\n",
" 1 0.79 0.51 0.62 198\n",
"\n",
"avg / total 0.89 0.89 0.88 1176\n",
"\n"
]
}
],
"source": [
"# 1. Confusion Matrix\n",
"print('\\n Train Confusion Matrix :')\n",
"display(pd.crosstab(y_train, pred_train, rownames=['Actual'], colnames=['Predict']))\n",
"\n",
"# 2. Accuracy\n",
"print('\\n Train accuracy :', accuracy_score(y_train, pred_train))\n",
"\n",
"# 3. Classification Report\n",
"print('\\n Classification Report : \\n', classification_report(y_train, pred_train))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the results on the test set."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Test Confusion Matrix :\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 235 20\n",
"1 33 6"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Test accuracy : 0.8197278911564626\n",
"\n",
" Classification Report : \n",
" precision recall f1-score support\n",
"\n",
" 0 0.88 0.92 0.90 255\n",
" 1 0.23 0.15 0.18 39\n",
"\n",
"avg / total 0.79 0.82 0.80 294\n",
"\n"
]
}
],
"source": [
"# 1. Confusion Matrix\n",
"print('\\n Test Confusion Matrix :')\n",
"display(pd.crosstab(y_test, pred_test, rownames=['Actual'], colnames=['Predict']))\n",
"\n",
"# 2. Accuracy\n",
"print('\\n Test accuracy :', accuracy_score(y_test, pred_test))\n",
"\n",
"# 3. Classification Report\n",
"print('\\n Classification Report : \\n', classification_report(y_test, pred_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Adjusting class weights"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A test accuracy of about 82% looks high, but it is not very meaningful."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 1233\n",
"1 237\n",
"Name: Attrition_idx, dtype: int64"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets.Attrition_idx.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8388"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Baseline accuracy if every sample is predicted as class 0\n",
"round(1233 / (1233 + 237), 4)"
]
},
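{
"cell_type": "markdown",
"metadata": {},
"source": [
"The counts show that class 0 outnumbers class 1 by roughly 5 to 1 (1233 vs. 237). A classifier that labels every sample as 0 would therefore already reach 1233/1470, or about 83.9% accuracy, which is higher than the 82% test accuracy obtained above. The model is not classifying class 1 (employees who leave) properly.\n",
"\n",
"If this model were used to prevent attrition, for example by giving larger bonuses to employees who are likely to leave, this would be a serious problem. A large share of the employees who actually left were predicted to stay (33 of 39 in the test set).\n",
"\n",
"Let's tune the model slightly by adjusting the class weights. Raising the weight of class 1 (leavers), for example, makes the model better at identifying employees who really are likely to leave, at the cost of flagging some employees who would have stayed as potential leavers. When the goal is to prevent attrition, that is an acceptable error. (Similarly, when granting loans, rejecting an applicant whose credit is only borderline acceptable is better than approving an applicant with poor credit.)"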
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try different class weights and compare the results."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# One row per class-weight setting tried below (11 settings), one column per metric\n",
"tuning_results = pd.DataFrame(np.empty((11, 10)))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"tuning_results.columns = ['class_0_weight', 'class_1_weight', \n",
" 'train_accuracy', 'test_accuracy', \n",
" 'precision_class_0', 'precision_class_1', 'precision_overall', \n",
" 'recall_calss_0', 'recall_class_1', 'recall_overall']"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['precision', 'recall', 'f1-score', 'support', '0', '0.88', '0.92', '0.90', '255', '1', '0.23', '0.15', '0.18', '39', 'avg', '/', 'total', '0.79', '0.82', '0.80', '294']\n"
]
}
],
"source": [
"# Token list used later to fill the results table (e.g. index 5 is the class 0 precision)\n",
"print(classification_report(y_test, pred_test).split())"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.01, 1: 0.99}\n",
"Test accuracy : 0.2925170068027211\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 50 205\n",
"1 3 36"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.1, 1: 0.9}\n",
"Test accuracy : 0.6972789115646258\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 183 72\n",
"1 17 22"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.2, 1: 0.8}\n",
"Test accuracy : 0.7891156462585034\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 216 39\n",
"1 23 16"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.3, 1: 0.7}\n",
"Test accuracy : 0.8095238095238095\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 226 29\n",
"1 27 12"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.4, 1: 0.6}\n",
"Test accuracy : 0.7993197278911565\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 225 30\n",
"1 29 10"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.5, 1: 0.5}\n",
"Test accuracy : 0.8197278911564626\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 235 20\n",
"1 33 6"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.6, 1: 0.4}\n",
"Test accuracy : 0.8469387755102041\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 247 8\n",
"1 37 2"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.7, 1: 0.30000000000000004}\n",
"Test accuracy : 0.8537414965986394\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 248 7\n",
"1 36 3"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.8, 1: 0.19999999999999996}\n",
"Test accuracy : 0.8571428571428571\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 250 5\n",
"1 37 2"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.9, 1: 0.09999999999999998}\n",
"Test accuracy : 0.8673469387755102\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 253 2\n",
"1 37 2"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 0.99, 1: 0.010000000000000009}\n",
"Test accuracy : 0.8707482993197279\n"
]
},
{
"data": {
"text/plain": [
"Predict 0 1\n",
"Actual \n",
"0 255 0\n",
"1 38 1"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"class_0_weight = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]\n",
"\n",
"for i in range(len(class_0_weight)):\n",
" class_weights = {0: class_0_weight[i], 1: 1 - class_0_weight[i]}\n",
" tree_classifier = DecisionTreeClassifier(criterion='gini',\n",
" max_depth=5,\n",
" min_samples_split=2,\n",
" min_samples_leaf=1,\n",
" random_state=42,\n",
" class_weight=class_weights)\n",
" tree_classifier.fit(x_train, y_train)\n",
" pred_train = tree_classifier.predict(x_train)\n",
" pred_test = tree_classifier.predict(x_test)\n",
" tuning_results.loc[i, 'class_0_weight'] = class_weights[0]\n",
" tuning_results.loc[i, 'class_1_weight'] = class_weights[1]\n",
" tuning_results.loc[i, 'train_accuracy'] = round(accuracy_score(y_train, pred_train), 4)\n",
" tuning_results.loc[i, 'test_accuracy'] = round(accuracy_score(y_test, pred_test), 4)\n",
" c_r = classification_report(y_test, pred_test).split()\n",
" tuning_results.loc[i, 'precision_class_0'] = float(c_r[5])\n",
" tuning_results.loc[i, 'precision_class_1'] = float(c_r[10])\n",
" tuning_results.loc[i, 'precision_overall'] = float(c_r[17])\n",
" tuning_results.loc[i, 'recall_calss_0'] = float(c_r[6])\n",
" tuning_results.loc[i, 'recall_class_1'] = float(c_r[11])\n",
" tuning_results.loc[i, 'recall_overall'] = float(c_r[18])\n",
"\n",
" print(class_weights)\n",
" print('Test accuracy :', accuracy_score(y_test, pred_test))\n",
" display(pd.crosstab(y_test, pred_test, rownames=['Actual'], colnames=['Predict']))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" class_0_weight class_1_weight train_accuracy test_accuracy \\\n",
"0 0.01 0.99 0.3580 0.2925 \n",
"1 0.10 0.90 0.7976 0.6973 \n",
"2 0.20 0.80 0.8759 0.7891 \n",
"3 0.30 0.70 0.8912 0.8095 \n",
"4 0.40 0.60 0.8903 0.7993 \n",
"5 0.50 0.50 0.8937 0.8197 \n",
"6 0.60 0.40 0.8954 0.8469 \n",
"7 0.70 0.30 0.8963 0.8537 \n",
"8 0.80 0.20 0.8869 0.8571 \n",
"9 0.90 0.10 0.8622 0.8673 \n",
"10 0.99 0.01 0.8435 0.8707 \n",
"\n",
" precision_class_0 precision_class_1 precision_overall recall_calss_0 \\\n",
"0 0.94 0.15 0.84 0.20 \n",
"1 0.92 0.23 0.82 0.72 \n",
"2 0.90 0.29 0.82 0.85 \n",
"3 0.89 0.29 0.81 0.89 \n",
"4 0.89 0.25 0.80 0.88 \n",
"5 0.88 0.23 0.79 0.92 \n",
"6 0.87 0.20 0.78 0.97 \n",
"7 0.87 0.30 0.80 0.97 \n",
"8 0.87 0.29 0.79 0.98 \n",
"9 0.87 0.50 0.82 0.99 \n",
"10 0.87 1.00 0.89 1.00 \n",
"\n",
" recall_class_1 recall_overall \n",
"0 0.92 0.29 \n",
"1 0.56 0.70 \n",
"2 0.41 0.79 \n",
"3 0.31 0.81 \n",
"4 0.26 0.80 \n",
"5 0.15 0.82 \n",
"6 0.05 0.85 \n",
"7 0.08 0.85 \n",
"8 0.05 0.86 \n",
"9 0.05 0.87 \n",
"10 0.03 0.87 "
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tuning_results"
]
},
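{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the weight of class 0 increases, the model predicts class 0 more often. Because more samples are predicted as class 0, recall for class 0 rises (from 0.20 to 1.00 in the table above), even though not all of those predictions are correct. For the same reason, precision for class 0 falls (from 0.94 to 0.87). Recall for class 1 moves the opposite way, dropping from 0.92 to 0.03."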
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the results, a class 0 weight of 0.3 gives a reasonable balance of accuracy, precision, and recall (test accuracy 0.81, class 1 precision 0.29, class 1 recall 0.31)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}