{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Titanic\n",
"\n",
"학습목표 : Titanic의 탑승자 정보를 통해 생존자를 예측하는 모델 만들기"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import findspark\n",
"findspark.init()\n",
"import pyspark\n",
"sc = pyspark.SparkContext()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- PassengerId: string (nullable = true)\n",
" |-- Survived: double (nullable = true)\n",
" |-- Pclass: double (nullable = true)\n",
" |-- Name: string (nullable = true)\n",
" |-- Sex: string (nullable = true)\n",
" |-- Age: string (nullable = true)\n",
" |-- SibSp: double (nullable = true)\n",
" |-- Parch: double (nullable = true)\n",
" |-- Ticket: string (nullable = true)\n",
" |-- Fare: string (nullable = true)\n",
" |-- Cabin: string (nullable = true)\n",
" |-- Embarked: string (nullable = false)\n",
" |-- label: double (nullable = true)\n",
"\n"
]
}
],
"source": [
"from pyspark.sql.session import SparkSession\n",
"from pyspark.sql.functions import *\n",
"from pyspark.sql.types import *\n",
"\n",
"spark = SparkSession(sc)\n",
"titanic = spark.read.option(\"header\", \"true\").csv(\"/Users/ryanshin/Downloads/train.csv\") \\\n",
" .withColumn(\"Survived\", col(\"Survived\").cast(\"double\")) \\\n",
" .withColumn(\"label\", col(\"Survived\")) \\\n",
" .withColumn(\"Pclass\", col(\"Pclass\").cast(\"double\"))\\\n",
" .withColumn(\"SibSp\", col(\"SibSp\").cast(\"double\"))\\\n",
" .withColumn(\"Parch\", col(\"Parch\").cast(\"double\"))\\\n",
" .na.fill(\"S\", \"Embarked\")\n",
"titanic.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n",
"|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|label|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n",
"| 1| 0.0| 3.0|Braund, Mr. Owen ...| male| 22| 1.0| 0.0| A/5 21171| 7.25| null| S| 0.0|\n",
"| 2| 1.0| 1.0|Cumings, Mrs. Joh...|female| 38| 1.0| 0.0| PC 17599|71.2833| C85| C| 1.0|\n",
"| 3| 1.0| 3.0|Heikkinen, Miss. ...|female| 26| 0.0| 0.0|STON/O2. 3101282| 7.925| null| S| 1.0|\n",
"| 4| 1.0| 1.0|Futrelle, Mrs. Ja...|female| 35| 1.0| 0.0| 113803| 53.1| C123| S| 1.0|\n",
"| 5| 0.0| 3.0|Allen, Mr. Willia...| male| 35| 0.0| 0.0| 373450| 8.05| null| S| 0.0|\n",
"| 6| 0.0| 3.0| Moran, Mr. James| male|null| 0.0| 0.0| 330877| 8.4583| null| Q| 0.0|\n",
"| 7| 0.0| 1.0|McCarthy, Mr. Tim...| male| 54| 0.0| 0.0| 17463|51.8625| E46| S| 0.0|\n",
"| 8| 0.0| 3.0|Palsson, Master. ...| male| 2| 3.0| 1.0| 349909| 21.075| null| S| 0.0|\n",
"| 9| 1.0| 3.0|Johnson, Mrs. Osc...|female| 27| 0.0| 2.0| 347742|11.1333| null| S| 1.0|\n",
"| 10| 1.0| 2.0|Nasser, Mrs. Nich...|female| 14| 1.0| 0.0| 237736|30.0708| null| C| 1.0|\n",
"| 11| 1.0| 3.0|Sandstrom, Miss. ...|female| 4| 1.0| 1.0| PP 9549| 16.7| G6| S| 1.0|\n",
"| 12| 1.0| 1.0|Bonnell, Miss. El...|female| 58| 0.0| 0.0| 113783| 26.55| C103| S| 1.0|\n",
"| 13| 0.0| 3.0|Saundercock, Mr. ...| male| 20| 0.0| 0.0| A/5. 2151| 8.05| null| S| 0.0|\n",
"| 14| 0.0| 3.0|Andersson, Mr. An...| male| 39| 1.0| 5.0| 347082| 31.275| null| S| 0.0|\n",
"| 15| 0.0| 3.0|Vestrom, Miss. Hu...|female| 14| 0.0| 0.0| 350406| 7.8542| null| S| 0.0|\n",
"| 16| 1.0| 2.0|Hewlett, Mrs. (Ma...|female| 55| 0.0| 0.0| 248706| 16| null| S| 1.0|\n",
"| 17| 0.0| 3.0|Rice, Master. Eugene| male| 2| 4.0| 1.0| 382652| 29.125| null| Q| 0.0|\n",
"| 18| 1.0| 2.0|Williams, Mr. Cha...| male|null| 0.0| 0.0| 244373| 13| null| S| 1.0|\n",
"| 19| 0.0| 3.0|Vander Planke, Mr...|female| 31| 1.0| 0.0| 345763| 18| null| S| 0.0|\n",
"| 20| 1.0| 3.0|Masselmani, Mrs. ...|female|null| 0.0| 0.0| 2649| 7.225| null| C| 1.0|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"titanic.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 데이터 타입\n",
"* 출처 : https://www.kaggle.com/c/titanic/data\n",
"* Survived : 살았으면 1, 죽었으면 0\n",
"* SibSp : 형제자매나 배우자가 몇명 있는지?\n",
"* Parch : 자식이 몇명인지?"
]
},
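{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to double-check these columns is `DataFrame.describe()`, which prints count, mean, stddev, min, and max for each column; a minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summary statistics for the columns described above\n",
"titanic.describe(\"Survived\", \"Pclass\", \"SibSp\", \"Parch\").show()"
]
},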
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------------+-------------+------------------------------------+\n",
"|count(PassengerId)|sum(Survived)|(sum(Survived) / count(PassengerId))|\n",
"+------------------+-------------+------------------------------------+\n",
"| 891| 342.0| 0.3838383838383838|\n",
"+------------------+-------------+------------------------------------+\n",
"\n"
]
}
],
"source": [
"titanic.select(count(\"PassengerId\"), sum(\"Survived\"), sum(\"Survived\")/count(\"PassengerId\")).show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 생존확율\n",
"* 생존자수( sum(\"Survived\") ) /전체승객수( count(\"PassengerId\") ) "
]
},
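{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `Survived` is already a 0/1 double, the same rate can also be computed as its mean; a minimal equivalent sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Survival rate as the mean of the 0/1 Survived column (equivalent to sum/count above)\n",
"titanic.select(avg(\"Survived\").alias(\"survival_rate\")).show()"
]
},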
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+-----+\n",
"|Survived|count|\n",
"+--------+-----+\n",
"| 0.0| 549|\n",
"| 1.0| 342|\n",
"+--------+-----+\n",
"\n"
]
}
],
"source": [
"titanic.groupBy(\"Survived\").count().show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+--------+-----+\n",
"|Pclass|Survived|count|\n",
"+------+--------+-----+\n",
"| 1.0| 0.0| 80|\n",
"| 1.0| 1.0| 136|\n",
"| 2.0| 0.0| 97|\n",
"| 2.0| 1.0| 87|\n",
"| 3.0| 0.0| 372|\n",
"| 3.0| 1.0| 119|\n",
"+------+--------+-----+\n",
"\n"
]
}
],
"source": [
"titanic.groupBy(\"Pclass\", \"Survived\").count().orderBy(\"Pclass\", \"Survived\").show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+--------+-----+\n",
"| Sex|Survived|count|\n",
"+------+--------+-----+\n",
"|female| 0.0| 81|\n",
"|female| 1.0| 233|\n",
"| male| 0.0| 468|\n",
"| male| 1.0| 109|\n",
"+------+--------+-----+\n",
"\n"
]
}
],
"source": [
"titanic.groupBy(\"Sex\", \"Survived\").count().orderBy(\"Sex\", \"Survived\").show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+--------+-----+\n",
"|SibSp|Survived|count|\n",
"+-----+--------+-----+\n",
"| 0.0| 0.0| 398|\n",
"| 0.0| 1.0| 210|\n",
"| 1.0| 0.0| 97|\n",
"| 1.0| 1.0| 112|\n",
"| 2.0| 0.0| 15|\n",
"| 2.0| 1.0| 13|\n",
"| 3.0| 0.0| 12|\n",
"| 3.0| 1.0| 4|\n",
"| 4.0| 0.0| 15|\n",
"| 4.0| 1.0| 3|\n",
"| 5.0| 0.0| 5|\n",
"| 8.0| 0.0| 7|\n",
"+-----+--------+-----+\n",
"\n"
]
}
],
"source": [
"titanic.groupBy(\"SibSp\", \"Survived\").count().orderBy(\"SibSp\", \"Survived\").show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+--------+-----+\n",
"|Parch|Survived|count|\n",
"+-----+--------+-----+\n",
"| 0.0| 0.0| 445|\n",
"| 0.0| 1.0| 233|\n",
"| 1.0| 0.0| 53|\n",
"| 1.0| 1.0| 65|\n",
"| 2.0| 0.0| 40|\n",
"| 2.0| 1.0| 40|\n",
"| 3.0| 0.0| 2|\n",
"| 3.0| 1.0| 3|\n",
"| 4.0| 0.0| 4|\n",
"| 5.0| 0.0| 4|\n",
"| 5.0| 1.0| 1|\n",
"| 6.0| 0.0| 1|\n",
"+-----+--------+-----+\n",
"\n"
]
}
],
"source": [
"titanic.groupBy(\"Parch\", \"Survived\").count().orderBy(\"Parch\", \"Survived\").show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# 다 죽었다고 예측\n",
"def predict1_func():\n",
" return 0.0\n",
"predict1 = udf(predict1_func, returnType=DoubleType())\n",
" \n",
"# 여자는 다 살았다고 남자는 다 죽었다고 예측\n",
"def predict2_func(gender):\n",
" if gender == \"female\":\n",
" return 1.0\n",
" else:\n",
" return 0.0 \n",
"predict2 = udf(predict2_func, returnType=DoubleType())\n",
" \n",
"# UDF 생성\n",
"prediction1result = titanic.select(predict1().alias(\"prediction\"), col(\"Survived\").cast(\"double\").alias(\"label\"))\n",
"prediction2result = titanic.select(predict2(\"Sex\").alias(\"prediction\"), col(\"Survived\").cast(\"double\").alias(\"label\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* pyspark 머신러닝 라이브러리를 활용하여 예측이 맞는지 확인.\n",
"* 출처 : http://spark.apache.org/docs/latest/api/python/pyspark.ml.html"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"prediction1result areaUnderROC=0.500000\n",
"prediction2result areaUnderROC=0.766873\n",
"prediction1result areaUnderPR=0.383838\n",
"prediction2result areaUnderPR=0.684957\n"
]
}
],
"source": [
"from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
"\n",
"evaluator = BinaryClassificationEvaluator()\n",
"evaluator.setRawPredictionCol(\"prediction\").setLabelCol(\"label\")\n",
"\n",
"evaluator.setMetricName(\"areaUnderROC\")\n",
"print(\"prediction1result areaUnderROC=%f\" % evaluator.evaluate(prediction1result))\n",
"print(\"prediction2result areaUnderROC=%f\" % evaluator.evaluate(prediction2result))\n",
"\n",
"evaluator.setMetricName(\"areaUnderPR\")\n",
"print(\"prediction1result areaUnderPR=%f\" % evaluator.evaluate(prediction1result))\n",
"print(\"prediction2result areaUnderPR=%f\" % evaluator.evaluate(prediction2result))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.classification import *\n",
"lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.feature import *\n",
"assembler = VectorAssembler().setInputCols([\"Pclass\", \"SibSp\"]).setOutputCol(\"features\")\n",
"data2 = assembler.transform(titanic)\n",
"lrModel = lr.fit(data2)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+---------+\n",
"|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|label| features|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+---------+\n",
"| 1| 0.0| 3.0|Braund, Mr. Owen ...| male| 22| 1.0| 0.0| A/5 21171| 7.25| null| S| 0.0|[3.0,1.0]|\n",
"| 2| 1.0| 1.0|Cumings, Mrs. Joh...|female| 38| 1.0| 0.0| PC 17599|71.2833| C85| C| 1.0|[1.0,1.0]|\n",
"| 3| 1.0| 3.0|Heikkinen, Miss. ...|female| 26| 0.0| 0.0|STON/O2. 3101282| 7.925| null| S| 1.0|[3.0,0.0]|\n",
"| 4| 1.0| 1.0|Futrelle, Mrs. Ja...|female| 35| 1.0| 0.0| 113803| 53.1| C123| S| 1.0|[1.0,1.0]|\n",
"| 5| 0.0| 3.0|Allen, Mr. Willia...| male| 35| 0.0| 0.0| 373450| 8.05| null| S| 0.0|[3.0,0.0]|\n",
"| 6| 0.0| 3.0| Moran, Mr. James| male|null| 0.0| 0.0| 330877| 8.4583| null| Q| 0.0|[3.0,0.0]|\n",
"| 7| 0.0| 1.0|McCarthy, Mr. Tim...| male| 54| 0.0| 0.0| 17463|51.8625| E46| S| 0.0|[1.0,0.0]|\n",
"| 8| 0.0| 3.0|Palsson, Master. ...| male| 2| 3.0| 1.0| 349909| 21.075| null| S| 0.0|[3.0,3.0]|\n",
"| 9| 1.0| 3.0|Johnson, Mrs. Osc...|female| 27| 0.0| 2.0| 347742|11.1333| null| S| 1.0|[3.0,0.0]|\n",
"| 10| 1.0| 2.0|Nasser, Mrs. Nich...|female| 14| 1.0| 0.0| 237736|30.0708| null| C| 1.0|[2.0,1.0]|\n",
"| 11| 1.0| 3.0|Sandstrom, Miss. ...|female| 4| 1.0| 1.0| PP 9549| 16.7| G6| S| 1.0|[3.0,1.0]|\n",
"| 12| 1.0| 1.0|Bonnell, Miss. El...|female| 58| 0.0| 0.0| 113783| 26.55| C103| S| 1.0|[1.0,0.0]|\n",
"| 13| 0.0| 3.0|Saundercock, Mr. ...| male| 20| 0.0| 0.0| A/5. 2151| 8.05| null| S| 0.0|[3.0,0.0]|\n",
"| 14| 0.0| 3.0|Andersson, Mr. An...| male| 39| 1.0| 5.0| 347082| 31.275| null| S| 0.0|[3.0,1.0]|\n",
"| 15| 0.0| 3.0|Vestrom, Miss. Hu...|female| 14| 0.0| 0.0| 350406| 7.8542| null| S| 0.0|[3.0,0.0]|\n",
"| 16| 1.0| 2.0|Hewlett, Mrs. (Ma...|female| 55| 0.0| 0.0| 248706| 16| null| S| 1.0|[2.0,0.0]|\n",
"| 17| 0.0| 3.0|Rice, Master. Eugene| male| 2| 4.0| 1.0| 382652| 29.125| null| Q| 0.0|[3.0,4.0]|\n",
"| 18| 1.0| 2.0|Williams, Mr. Cha...| male|null| 0.0| 0.0| 244373| 13| null| S| 1.0|[2.0,0.0]|\n",
"| 19| 0.0| 3.0|Vander Planke, Mr...|female| 31| 1.0| 0.0| 345763| 18| null| S| 0.0|[3.0,1.0]|\n",
"| 20| 1.0| 3.0|Masselmani, Mrs. ...|female|null| 0.0| 0.0| 2649| 7.225| null| C| 1.0|[3.0,0.0]|\n",
"+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+---------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"data2.show()"
]
},
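{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fitted logistic regression model can be scored the same way as the hand-written baselines. A minimal sketch, evaluated on the training data for illustration only (no train/test split), using the model's `rawPrediction` column:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: score the logistic regression model with the same kind of evaluator as above\n",
"from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
"\n",
"lr_predictions = lrModel.transform(data2)\n",
"lr_evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\", labelCol=\"label\",\n",
"                                             metricName=\"areaUnderROC\")\n",
"print(\"logistic regression areaUnderROC=%f\" % lr_evaluator.evaluate(lr_predictions))"
]
},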
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import IFrame\n",
"IFrame('https://www.zepl.com/viewer/notebooks/bm90ZTovL1NEUkx1cmtlci8wMDM2MGM2ZWQzZWM0NjQyYjdlMTk0YzhlZmVmMDNjOC9ub3RlLmpzb24', width='100%', height=600)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}