{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Titanic\n", "\n", "학습목표 : Titanic의 탑승자 정보를 통해 생존자를 예측하는 모델 만들기" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import findspark\n", "findspark.init()\n", "import pyspark\n", "sc = pyspark.SparkContext()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- PassengerId: string (nullable = true)\n", " |-- Survived: double (nullable = true)\n", " |-- Pclass: double (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: string (nullable = true)\n", " |-- SibSp: double (nullable = true)\n", " |-- Parch: double (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: string (nullable = true)\n", " |-- Cabin: string (nullable = true)\n", " |-- Embarked: string (nullable = false)\n", " |-- label: double (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.sql.session import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "\n", "spark = SparkSession(sc)\n", "titanic = spark.read.option(\"header\", \"true\").csv(\"/Users/ryanshin/Downloads/train.csv\") \\\n", " .withColumn(\"Survived\", col(\"Survived\").cast(\"double\")) \\\n", " .withColumn(\"label\", col(\"Survived\")) \\\n", " .withColumn(\"Pclass\", col(\"Pclass\").cast(\"double\"))\\\n", " .withColumn(\"SibSp\", col(\"SibSp\").cast(\"double\"))\\\n", " .withColumn(\"Parch\", col(\"Parch\").cast(\"double\"))\\\n", " .na.fill(\"S\", \"Embarked\")\n", "titanic.printSchema()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n", "|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|label|\n", "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n", "| 1| 0.0| 3.0|Braund, Mr. Owen ...| male| 22| 1.0| 0.0| A/5 21171| 7.25| null| S| 0.0|\n", "| 2| 1.0| 1.0|Cumings, Mrs. Joh...|female| 38| 1.0| 0.0| PC 17599|71.2833| C85| C| 1.0|\n", "| 3| 1.0| 3.0|Heikkinen, Miss. ...|female| 26| 0.0| 0.0|STON/O2. 3101282| 7.925| null| S| 1.0|\n", "| 4| 1.0| 1.0|Futrelle, Mrs. Ja...|female| 35| 1.0| 0.0| 113803| 53.1| C123| S| 1.0|\n", "| 5| 0.0| 3.0|Allen, Mr. Willia...| male| 35| 0.0| 0.0| 373450| 8.05| null| S| 0.0|\n", "| 6| 0.0| 3.0| Moran, Mr. James| male|null| 0.0| 0.0| 330877| 8.4583| null| Q| 0.0|\n", "| 7| 0.0| 1.0|McCarthy, Mr. Tim...| male| 54| 0.0| 0.0| 17463|51.8625| E46| S| 0.0|\n", "| 8| 0.0| 3.0|Palsson, Master. ...| male| 2| 3.0| 1.0| 349909| 21.075| null| S| 0.0|\n", "| 9| 1.0| 3.0|Johnson, Mrs. Osc...|female| 27| 0.0| 2.0| 347742|11.1333| null| S| 1.0|\n", "| 10| 1.0| 2.0|Nasser, Mrs. Nich...|female| 14| 1.0| 0.0| 237736|30.0708| null| C| 1.0|\n", "| 11| 1.0| 3.0|Sandstrom, Miss. ...|female| 4| 1.0| 1.0| PP 9549| 16.7| G6| S| 1.0|\n", "| 12| 1.0| 1.0|Bonnell, Miss. El...|female| 58| 0.0| 0.0| 113783| 26.55| C103| S| 1.0|\n", "| 13| 0.0| 3.0|Saundercock, Mr. ...| male| 20| 0.0| 0.0| A/5. 2151| 8.05| null| S| 0.0|\n", "| 14| 0.0| 3.0|Andersson, Mr. An...| male| 39| 1.0| 5.0| 347082| 31.275| null| S| 0.0|\n", "| 15| 0.0| 3.0|Vestrom, Miss. 
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n", "|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|label|\n", "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n", "| 1| 0.0| 3.0|Braund, Mr. Owen ...| male| 22| 1.0| 0.0| A/5 21171| 7.25| null| S| 0.0|\n", "| 2| 1.0| 1.0|Cumings, Mrs. Joh...|female| 38| 1.0| 0.0| PC 17599|71.2833| C85| C| 1.0|\n", "| 3| 1.0| 3.0|Heikkinen, Miss. ...|female| 26| 0.0| 0.0|STON/O2. 3101282| 7.925| null| S| 1.0|\n", "| 4| 1.0| 1.0|Futrelle, Mrs. Ja...|female| 35| 1.0| 0.0| 113803| 53.1| C123| S| 1.0|\n", "| 5| 0.0| 3.0|Allen, Mr. Willia...| male| 35| 0.0| 0.0| 373450| 8.05| null| S| 0.0|\n", "| 6| 0.0| 3.0| Moran, Mr. James| male|null| 0.0| 0.0| 330877| 8.4583| null| Q| 0.0|\n", "| 7| 0.0| 1.0|McCarthy, Mr. Tim...| male| 54| 0.0| 0.0| 17463|51.8625| E46| S| 0.0|\n", "| 8| 0.0| 3.0|Palsson, Master. ...| male| 2| 3.0| 1.0| 349909| 21.075| null| S| 0.0|\n", "| 9| 1.0| 3.0|Johnson, Mrs. Osc...|female| 27| 0.0| 2.0| 347742|11.1333| null| S| 1.0|\n", "| 10| 1.0| 2.0|Nasser, Mrs. Nich...|female| 14| 1.0| 0.0| 237736|30.0708| null| C| 1.0|\n", "| 11| 1.0| 3.0|Sandstrom, Miss. ...|female| 4| 1.0| 1.0| PP 9549| 16.7| G6| S| 1.0|\n", "| 12| 1.0| 1.0|Bonnell, Miss. El...|female| 58| 0.0| 0.0| 113783| 26.55| C103| S| 1.0|\n", "| 13| 0.0| 3.0|Saundercock, Mr. ...| male| 20| 0.0| 0.0| A/5. 2151| 8.05| null| S| 0.0|\n", "| 14| 0.0| 3.0|Andersson, Mr. An...| male| 39| 1.0| 5.0| 347082| 31.275| null| S| 0.0|\n", "| 15| 0.0| 3.0|Vestrom, Miss. Hu...|female| 14| 0.0| 0.0| 350406| 7.8542| null| S| 0.0|\n", "| 16| 1.0| 2.0|Hewlett, Mrs. (Ma...|female| 55| 0.0| 0.0| 248706| 16| null| S| 1.0|\n", "| 17| 0.0| 3.0|Rice, Master. Eugene| male| 2| 4.0| 1.0| 382652| 29.125| null| Q| 0.0|\n", "| 18| 1.0| 2.0|Williams, Mr. Cha...| male|null| 0.0| 0.0| 244373| 13| null| S| 1.0|\n", "| 19| 0.0| 3.0|Vander Planke, Mr...|female| 31| 1.0| 0.0| 345763| 18| null| S| 0.0|\n", "| 20| 1.0| 3.0|Masselmani, Mrs. ...|female|null| 0.0| 0.0| 2649| 7.225| null| C| 1.0|\n", "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "titanic.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data dictionary\n", "* Source: https://www.kaggle.com/c/titanic/data\n", "* Survived: 1 if the passenger survived, 0 otherwise\n", "* SibSp: number of siblings or spouses aboard\n", "* Parch: number of parents or children aboard" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------------+-------------+------------------------------------+\n", "|count(PassengerId)|sum(Survived)|(sum(Survived) / count(PassengerId))|\n", "+------------------+-------------+------------------------------------+\n", "| 891| 342.0| 0.3838383838383838|\n", "+------------------+-------------+------------------------------------+\n", "\n" ] } ], "source": [ "titanic.select(count(\"PassengerId\"), sum(\"Survived\"), sum(\"Survived\")/count(\"PassengerId\")).show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Survival rate\n", "* Number of survivors ( sum(\"Survived\") ) / total number of passengers ( count(\"PassengerId\") )" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+--------+-----+\n", "|Survived|count|\n", "+--------+-----+\n", "| 0.0| 549|\n", "| 1.0| 342|\n", "+--------+-----+\n", "\n" ] } ], "source": [ "titanic.groupBy(\"Survived\").count().show()" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------+--------+-----+\n", "|Pclass|Survived|count|\n", "+------+--------+-----+\n", "| 1.0| 0.0| 80|\n", "| 1.0| 1.0| 136|\n", "| 2.0| 0.0| 97|\n", "| 2.0| 1.0| 87|\n", "| 3.0| 0.0| 372|\n", "| 3.0| 1.0| 119|\n", "+------+--------+-----+\n", "\n" ] } ], "source": [ "titanic.groupBy(\"Pclass\", \"Survived\").count().orderBy(\"Pclass\", \"Survived\").show()" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------+--------+-----+\n", "| Sex|Survived|count|\n", "+------+--------+-----+\n", "|female| 0.0| 81|\n", "|female| 1.0| 233|\n", "| male| 0.0| 468|\n", "| male| 1.0| 109|\n", "+------+--------+-----+\n", "\n" ] } ], "source": [ "titanic.groupBy(\"Sex\", \"Survived\").count().orderBy(\"Sex\", \"Survived\").show()" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------+-----+\n", "|SibSp|Survived|count|\n", "+-----+--------+-----+\n", "| 0.0| 0.0| 398|\n", "| 0.0| 1.0| 210|\n", "| 1.0| 0.0| 97|\n", "| 1.0| 1.0| 112|\n", "| 2.0| 0.0| 15|\n", "| 2.0| 1.0| 13|\n", "| 3.0| 0.0| 12|\n", "| 3.0| 1.0| 4|\n", "| 4.0| 0.0| 15|\n", "| 4.0| 1.0| 3|\n", "| 5.0| 0.0| 5|\n", "| 8.0| 0.0| 7|\n", "+-----+--------+-----+\n", "\n" ] } ], "source": [ "titanic.groupBy(\"SibSp\", \"Survived\").count().orderBy(\"SibSp\", \"Survived\").show()" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------+-----+\n", "|Parch|Survived|count|\n", "+-----+--------+-----+\n", "| 0.0| 0.0| 445|\n", "| 0.0| 1.0| 233|\n", "| 1.0| 0.0| 53|\n", "| 1.0| 1.0| 65|\n", "| 2.0| 0.0| 40|\n", "| 2.0| 1.0| 40|\n", "| 3.0| 0.0| 2|\n", "| 3.0| 1.0| 3|\n", "| 4.0| 0.0| 4|\n", "| 5.0| 0.0| 4|\n", "| 5.0| 1.0| 1|\n", "| 6.0| 0.0| 1|\n", "+-----+--------+-----+\n", "\n" ] } ], "source": [ "titanic.groupBy(\"Parch\", \"Survived\").count().orderBy(\"Parch\", \"Survived\").show()" ] },
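{ "cell_type": "markdown", "metadata": {}, "source": [ "The grouped counts above can also be read as survival rates. The next cell is an unexecuted sketch that computes the rate directly by averaging the 0/1 Survived column per group, using only functions already imported from pyspark.sql.functions." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch, not executed: survival rate per group via avg() on the 0/1 Survived column\n", "titanic.groupBy(\"Sex\").agg(count(\"PassengerId\").alias(\"passengers\"), avg(\"Survived\").alias(\"survival_rate\")).orderBy(\"Sex\").show()\n", "titanic.groupBy(\"Pclass\").agg(count(\"PassengerId\").alias(\"passengers\"), avg(\"Survived\").alias(\"survival_rate\")).orderBy(\"Pclass\").show()" ] },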
\"Survived\").show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----+--------+-----+\n", "|Parch|Survived|count|\n", "+-----+--------+-----+\n", "| 0.0| 0.0| 445|\n", "| 0.0| 1.0| 233|\n", "| 1.0| 0.0| 53|\n", "| 1.0| 1.0| 65|\n", "| 2.0| 0.0| 40|\n", "| 2.0| 1.0| 40|\n", "| 3.0| 0.0| 2|\n", "| 3.0| 1.0| 3|\n", "| 4.0| 0.0| 4|\n", "| 5.0| 0.0| 4|\n", "| 5.0| 1.0| 1|\n", "| 6.0| 0.0| 1|\n", "+-----+--------+-----+\n", "\n" ] } ], "source": [ "titanic.groupBy(\"Parch\", \"Survived\").count().orderBy(\"Parch\", \"Survived\").show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# 다 죽었다고 예측\n", "def predict1_func():\n", " return 0.0\n", "predict1 = udf(predict1_func, returnType=DoubleType())\n", " \n", "# 여자는 다 살았다고 남자는 다 죽었다고 예측\n", "def predict2_func(gender):\n", " if gender == \"female\":\n", " return 1.0\n", " else:\n", " return 0.0 \n", "predict2 = udf(predict2_func, returnType=DoubleType())\n", " \n", "# UDF 생성\n", "prediction1result = titanic.select(predict1().alias(\"prediction\"), col(\"Survived\").cast(\"double\").alias(\"label\"))\n", "prediction2result = titanic.select(predict2(\"Sex\").alias(\"prediction\"), col(\"Survived\").cast(\"double\").alias(\"label\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* pyspark 머신러닝 라이브러리를 활용하여 예측이 맞는지 확인.\n", "* 출처 : http://spark.apache.org/docs/latest/api/python/pyspark.ml.html" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "prediction1result areaUnderROC=0.500000\n", "prediction2result areaUnderROC=0.766873\n", "prediction1result areaUnderPR=0.383838\n", "prediction2result areaUnderPR=0.684957\n" ] } ], "source": [ "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", "\n", "evaluator = BinaryClassificationEvaluator()\n", "evaluator.setRawPredictionCol(\"prediction\").setLabelCol(\"label\")\n", "\n", "evaluator.setMetricName(\"areaUnderROC\")\n", "print(\"prediction1result areaUnderROC=%f\" % evaluator.evaluate(prediction1result))\n", "print(\"prediction2result areaUnderROC=%f\" % evaluator.evaluate(prediction2result))\n", "\n", "evaluator.setMetricName(\"areaUnderPR\")\n", "print(\"prediction1result areaUnderPR=%f\" % evaluator.evaluate(prediction1result))\n", "print(\"prediction2result areaUnderPR=%f\" % evaluator.evaluate(prediction2result))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.classification import *\n", "lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import *\n", "assembler = VectorAssembler().setInputCols([\"Pclass\", \"SibSp\"]).setOutputCol(\"features\")\n", "data2 = assembler.transform(titanic)\n", "lrModel = lr.fit(data2)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+---------+\n", "|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|label| features|\n", "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----+---------+\n", "| 1| 0.0| 3.0|Braund, Mr. 
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import IFrame\n", "IFrame('https://www.zepl.com/viewer/notebooks/bm90ZTovL1NEUkx1cmtlci8wMDM2MGM2ZWQzZWM0NjQyYjdlMTk0YzhlZmVmMDNjOC9ub3RlLmpzb24', width='100%', height=600)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }