{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from pyspark.sql import SparkSession\n", "import pandas as pd\n", "\n", "import sys\n", "sys.path.append('..')\n", "from utils.pysparkutils import *\n", "\n", "spark = SparkSession.builder.appName(\"titanic\").getOrCreate()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- PassengerId: integer (nullable = true)\n", " |-- Survived: integer (nullable = true)\n", " |-- Pclass: integer (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: double (nullable = true)\n", " |-- SibSp: integer (nullable = true)\n", " |-- Parch: integer (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: double (nullable = true)\n", " |-- Cabin: string (nullable = true)\n", " |-- Embarked: string (nullable = true)\n", "\n", "root\n", " |-- PassengerId: integer (nullable = true)\n", " |-- Pclass: integer (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: double (nullable = true)\n", " |-- SibSp: integer (nullable = true)\n", " |-- Parch: integer (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: double (nullable = true)\n", " |-- Cabin: string (nullable = true)\n", " |-- Embarked: string (nullable = true)\n", "\n" ] } ], "source": [ "train = spark.read.csv('./train.csv', header=\"true\", inferSchema=\"true\")\n", "test = spark.read.csv('./test.csv', header=\"true\", inferSchema=\"true\")\n", "\n", "train.printSchema()\n", "test.printSchema()\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NoneS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NoneS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NoneS
5603Moran, Mr. JamesmaleNaN003308778.4583NoneQ
6701McCarthy, Mr. Timothy Jmale54.0001746351.8625E46S
7803Palsson, Master. Gosta Leonardmale2.03134990921.0750NoneS
8913Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)female27.00234774211.1333NoneS
91012Nasser, Mrs. Nicholas (Adele Achem)female14.01023773630.0708NoneC
101113Sandstrom, Miss. Marguerite Rutfemale4.011PP 954916.7000G6S
111211Bonnell, Miss. Elizabethfemale58.00011378326.5500C103S
121303Saundercock, Mr. William Henrymale20.000A/5. 21518.0500NoneS
131403Andersson, Mr. Anders Johanmale39.01534708231.2750NoneS
141503Vestrom, Miss. Hulda Amanda Adolfinafemale14.0003504067.8542NoneS
151612Hewlett, Mrs. (Mary D Kingcome)female55.00024870616.0000NoneS
161703Rice, Master. Eugenemale2.04138265229.1250NoneQ
171812Williams, Mr. Charles EugenemaleNaN0024437313.0000NoneS
181903Vander Planke, Mrs. Julius (Emelia Maria Vande...female31.01034576318.0000NoneS
192013Masselmani, Mrs. FatimafemaleNaN0026497.2250NoneC
\n", "
" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", "5 6 0 3 \n", "6 7 0 1 \n", "7 8 0 3 \n", "8 9 1 3 \n", "9 10 1 2 \n", "10 11 1 3 \n", "11 12 1 1 \n", "12 13 0 3 \n", "13 14 0 3 \n", "14 15 0 3 \n", "15 16 1 2 \n", "16 17 0 3 \n", "17 18 1 2 \n", "18 19 0 3 \n", "19 20 1 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "5 Moran, Mr. James male NaN 0 \n", "6 McCarthy, Mr. Timothy J male 54.0 0 \n", "7 Palsson, Master. Gosta Leonard male 2.0 3 \n", "8 Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) female 27.0 0 \n", "9 Nasser, Mrs. Nicholas (Adele Achem) female 14.0 1 \n", "10 Sandstrom, Miss. Marguerite Rut female 4.0 1 \n", "11 Bonnell, Miss. Elizabeth female 58.0 0 \n", "12 Saundercock, Mr. William Henry male 20.0 0 \n", "13 Andersson, Mr. Anders Johan male 39.0 1 \n", "14 Vestrom, Miss. Hulda Amanda Adolfina female 14.0 0 \n", "15 Hewlett, Mrs. (Mary D Kingcome) female 55.0 0 \n", "16 Rice, Master. Eugene male 2.0 4 \n", "17 Williams, Mr. Charles Eugene male NaN 0 \n", "18 Vander Planke, Mrs. Julius (Emelia Maria Vande... female 31.0 1 \n", "19 Masselmani, Mrs. Fatima female NaN 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 None S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 None S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 None S \n", "5 0 330877 8.4583 None Q \n", "6 0 17463 51.8625 E46 S \n", "7 1 349909 21.0750 None S \n", "8 2 347742 11.1333 None S \n", "9 0 237736 30.0708 None C \n", "10 1 PP 9549 16.7000 G6 S \n", "11 0 113783 26.5500 C103 S \n", "12 0 A/5. 
2151   8.0500  None        S  \n", "13             5     347082  31.2750  None        S  \n", "14             0     350406   7.8542  None        S  \n", "15             0     248706  16.0000  None        S  \n", "16             1     382652  29.1250  None        Q  \n", "17             0     244373  13.0000  None        S  \n", "18             0     345763  18.0000  None        S  \n", "19             0       2649   7.2250  None        C  " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.limit(20).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section we will explore missing data." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Age', 0.19865319865319866),\n", " ('Cabin', 0.7710437710437711),\n", " ('Embarked', 0.002244668911335578)]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "findMissingValuesCols(train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that almost 80% of the Cabin column is missing data, so we will drop the Cabin column.\n", "Very little data is missing in the Embarked column, so we will just drop those rows."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import Imputer\n", "ageImputer = Imputer(inputCols=['Age'], outputCols=['imputedAge'], strategy='median')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- PassengerId: integer (nullable = true)\n", " |-- Survived: integer (nullable = true)\n", " |-- Pclass: integer (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: double (nullable = true)\n", " |-- SibSp: integer (nullable = true)\n", " |-- Parch: integer (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: double (nullable = true)\n", " |-- Embarked: string (nullable = true)\n", "\n" ] }, { "data": { "text/plain": [ "889" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train = train.filter(train.Embarked.isNotNull())\n", "train = train.drop('Cabin')\n", "train.printSchema()\n", "train.count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exploratory Data Analysis\n", "In the next few sections, we will explore the training data and the relationship between the different features and the label.\n", "As we already know, most of the passengers on the Titanic didn't survive. Our training data suggests the same: only around one-third of the passengers survived. The distributions of passenger class and sex are similarly imbalanced." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Survivedcount
01340
10549
\n", "
" ], "text/plain": [ " Survived count\n", "0 1 340\n", "1 0 549" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labelCol = 'Survived'\n", "train.groupby(labelCol).count().toPandas()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Survived_Sexfemalemale
01231109
1081468
\n", "
" ], "text/plain": [ "  Survived_Sex  female  male\n", "0            1     231   109\n", "1            0      81   468" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Sex').toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pointwise Mutual Information (PMI) is a useful metric for exploring the relationship between two categorical features. PMI gives a scalar value for each pair of values of the two features. The value denotes how much information observing one value provides about the co-occurring value of the other feature.\n", "\n", "PMI can be normalized to [-1,+1], which is called Normalized PMI, resulting in: \n", "* -1 (in the limit) for never occurring together,\n", "* 0 for independence,\n", "* +1 for complete co-occurrence\n", "\n", "\n", "The `calcNormalizedPointwiseMutualInformation` function is implemented in the `pysparkutils.py` file in the `utils` directory." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SexSurvivedNormalized PMI
2female0-0.361721
1female10.490151
0male00.424895
3male1-0.336078
\n", "
" ], "text/plain": [ " Sex Survived Normalized PMI\n", "2 female 0 -0.361721\n", "1 female 1 0.490151\n", "0 male 0 0.424895\n", "3 male 1 -0.336078" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Sex', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Sex', labelCol)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Survived_Pclass123
0113487119
108097372
\n", "
" ], "text/plain": [ " Survived_Pclass 1 2 3\n", "0 1 134 87 119\n", "1 0 80 97 372" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Pclass').toPandas()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PclassSurvivedNormalized PMI
010-0.208445
2110.260544
420-0.071421
3210.091268
5300.234674
131-0.226840
\n", "
" ], "text/plain": [ " Pclass Survived Normalized PMI\n", "0 1 0 -0.208445\n", "2 1 1 0.260544\n", "4 2 0 -0.071421\n", "3 2 1 0.091268\n", "5 3 0 0.234674\n", "1 3 1 -0.226840" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Pclass', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Pclass', labelCol)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Survived_EmbarkedCQS
019330217
107547427
\n", "
" ], "text/plain": [ " Survived_Embarked C Q S\n", "0 1 93 30 217\n", "1 0 75 47 427" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Embarked').toPandas()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EmbarkedSurvivedNormalized PMI
5C0-0.131229
3C10.163804
4Q0-0.003966
0Q10.005472
1S00.096935
2S1-0.089810
\n", "
" ], "text/plain": [ " Embarked Survived Normalized PMI\n", "5 C 0 -0.131229\n", "3 C 1 0.163804\n", "4 Q 0 -0.003966\n", "0 Q 1 0.005472\n", "1 S 0 0.096935\n", "2 S 1 -0.089810" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Embarked', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Embarked', labelCol)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Survived_SibSp0123458
01208112134300
103989715121557
\n", "
" ], "text/plain": [ " Survived_SibSp 0 1 2 3 4 5 8\n", "0 1 208 112 13 4 3 0 0\n", "1 0 398 97 15 12 15 5 7" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'SibSp').toPandas()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
SibSpSurvivedNormalized PMI
4000.076614
601-0.074483
010-0.128928
3110.162829
720-0.034825
5210.045891
10300.045135
131-0.078675
2400.073413
1141-0.145939
8500.093038
9800.099500
\n", "
" ], "text/plain": [ " SibSp Survived Normalized PMI\n", "4 0 0 0.076614\n", "6 0 1 -0.074483\n", "0 1 0 -0.128928\n", "3 1 1 0.162829\n", "7 2 0 -0.034825\n", "5 2 1 0.045891\n", "10 3 0 0.045135\n", "1 3 1 -0.078675\n", "2 4 0 0.073413\n", "11 4 1 -0.145939\n", "8 5 0 0.093038\n", "9 8 0 0.099500" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'SibSp', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'SibSp', labelCol)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Survived_Parch0123456
0123165403010
1044553402441
\n", "
" ], "text/plain": [ " Survived_Parch 0 1 2 3 4 5 6\n", "0 1 231 65 40 3 0 1 0\n", "1 0 445 53 40 2 4 4 1" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Parch').toPandas()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ParchSurvivedNormalized PMI
5000.092309
701-0.083569
010-0.112913
4110.139486
820-0.068086
6210.086419
1130-0.071231
1310.079123
3400.089196
9500.047902
1051-0.095475
2600.070986
\n", "
" ], "text/plain": [ "    Parch  Survived  Normalized PMI\n", "5       0         0        0.092309\n", "7       0         1       -0.083569\n", "0       1         0       -0.112913\n", "4       1         1        0.139486\n", "8       2         0       -0.068086\n", "6       2         1        0.086419\n", "11      3         0       -0.071231\n", "1       3         1        0.079123\n", "3       4         0        0.089196\n", "9       5         0        0.047902\n", "10      5         1       -0.095475\n", "2       6         0        0.070986" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Parch', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Parch', labelCol)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will calculate the normalized entropy of the categorical features; entropy measures how spread out a categorical feature's values are, playing the role that variance plays for numerical features." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FeatureEntropy
0Sex0.934919
1Pclass0.907245
2Embarked0.692048
3SibSp0.477435
4Parch0.402510
\n", "
" ], "text/plain": [ " Feature Entropy\n", "0 Sex 0.934919\n", "1 Pclass 0.907245\n", "2 Embarked 0.692048\n", "3 SibSp 0.477435\n", "4 Parch 0.402510" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "columns = ['Sex', 'Pclass', 'Embarked', 'SibSp', 'Parch']\n", "entropies = calcNormalizedEntropy(train, *columns)\n", "dictToPandasDF(entropies, 'Feature', 'Entropy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Categorical feature independence test via chi square test." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pValues: [0.0,0.0,6.02813466444e-06,4.02603175464e-07,4.93843854699e-11,0.0101897422598,0.0315461645121,4.25298058386e-05,0.00150622342036,0.612884928604,0.116808580457]\n", "degreesOfFreedom: [2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", "statistics: [100.980407261,260.756342249,20.4792462347,25.6818141585,43.2014376768,6.60142083307,4.62299205997,16.755010532,10.0709867693,0.255995235761,2.45959925254]\n" ] } ], "source": [ "from pyspark.ml import Pipeline\n", "from pyspark.ml.stat import ChiSquareTest\n", "from pyspark.ml.feature import Bucketizer, OneHotEncoderEstimator, StringIndexer, VectorAssembler, VectorIndexer\n", "\n", "edaEmbarkedIndexer = StringIndexer(inputCol='Embarked', outputCol='indexedEmbarked')\n", "edaSexIndexer = StringIndexer(inputCol='Sex', outputCol='indexedSex')\n", "\n", "edaAgeImputer = Imputer(inputCols=['Age'], outputCols=['imputedAge'], strategy='median')\n", "\n", "ageSplits = [0, 16, 32, 48, 64, 200]\n", "edaAgeBucketizer = Bucketizer(splits=ageSplits, inputCol='imputedAge', outputCol='bucketedAge')\n", "\n", "fareSplits = [-float('inf'), 7.91, 14.454, 31, float('inf')]\n", "edaFareBucketizer = Bucketizer(splits=fareSplits, inputCol='Fare', outputCol='bucketedFare')\n", "\n", "oneHotEncoderEstimator = OneHotEncoderEstimator(inputCols=['indexedSex', 'indexedEmbarked', 'bucketedFare', 
'bucketedAge'], \n", " outputCols=['oneHotSex', 'oneHotEmbarked','oneHotFare', 'oneHotAge'])\n", "inputCols=['Pclass', 'oneHotSex', 'oneHotEmbarked','oneHotFare', 'oneHotAge']\n", "edaAssembler = VectorAssembler(inputCols=inputCols, outputCol='features')\n", "\n", "pipeline = Pipeline(stages=[edaEmbarkedIndexer, edaSexIndexer, edaAgeImputer, edaAgeBucketizer, \n", " edaFareBucketizer, oneHotEncoderEstimator, edaAssembler])\n", "chiSqTrain = pipeline.fit(train).transform(train)\n", "\n", "r = ChiSquareTest.test(chiSqTrain, 'features', 'Survived').head()\n", "print(\"pValues: \" + str(r.pValues))\n", "print(\"degreesOfFreedom: \" + str(r.degreesOfFreedom))\n", "print(\"statistics: \" + str(r.statistics))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import StringIndexer\n", "\n", "embarkedIndexer = StringIndexer(inputCol='Embarked', outputCol='indexedEmbarked', handleInvalid='skip')\n", "sexFeatureIndexer = StringIndexer(inputCol='Sex', outputCol='indexedSex', handleInvalid='skip')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import Bucketizer\n", "\n", "ageSplits = [0, 16, 32, 48, 64, 200]\n", "ageBucketizer = Bucketizer(splits=ageSplits, inputCol='imputedAge', outputCol='bucketedAge', handleInvalid='skip')\n", "fareSplits = [-float('inf'), 7.91, 14.454, 31, float('inf')]\n", "fareBucketizer = Bucketizer(splits=fareSplits, inputCol='Fare', outputCol='bucketedFare', handleInvalid='skip')" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import OneHotEncoderEstimator, VectorIndexer\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.ml.classification import RandomForestClassifier\n", "\n", "oneHotEncoderEstimator = 
OneHotEncoderEstimator(inputCols=['indexedSex', 'indexedEmbarked', 'bucketedFare', 'bucketedAge'], \n", " outputCols=['oneHotSex', 'oneHotEmbarked','oneHotFare', 'oneHotAge'])\n", "assembler = VectorAssembler(inputCols=['Pclass', 'SibSp', 'Parch', 'bucketedAge', \n", " 'bucketedFare', 'indexedEmbarked', 'indexedSex'], outputCol='features')\n", "rf = RandomForestClassifier(labelCol=labelCol, featuresCol='features')" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.tuning import CrossValidator, ParamGridBuilder\n", "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", "from pyspark.ml import Pipeline\n", "\n", "pipeline = Pipeline(stages=[ageImputer, embarkedIndexer, sexFeatureIndexer, ageBucketizer, \n", " fareBucketizer, oneHotEncoderEstimator, assembler, rf])\n", "\n", "grid = ParamGridBuilder().addGrid(rf.numTrees, [15, 20, 25, 30])\\\n", " .addGrid(rf.maxDepth, [5, 8])\\\n", " .build()\n", "\n", "cv = CrossValidator(estimator=pipeline, \n", " estimatorParamMaps=grid, \n", " evaluator=BinaryClassificationEvaluator(labelCol=labelCol, metricName='areaUnderROC'), \n", " numFolds=10)\n", "\n", "model = cv.fit(train)\n", "train = model.transform(train)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9265509482481523" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluator = model.getEvaluator()\n", "evaluator.evaluate(train)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "test = model.transform(test)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdPclassNameSexAgeSibSpParchTicketFareCabin...bucketedAgebucketedFareoneHotSexoneHotEmbarkedoneHotFareoneHotAgefeaturesrawPredictionprobabilityprediction
08923Kelly, Mr. Jamesmale34.5003309117.8292None...2.00.0(1.0)(0.0, 0.0)(1.0, 0.0, 0.0)(0.0, 0.0, 1.0, 0.0)(3.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0)[29.2960574888, 0.703942511236][0.976535249625, 0.0234647503745]0.0
18933Wilkes, Mrs. James (Ellen Needs)female47.0103632727.0000None...2.00.0(0.0)(1.0, 0.0)(1.0, 0.0, 0.0)(0.0, 0.0, 1.0, 0.0)[3.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0][17.9363522365, 12.0636477635][0.597878407882, 0.402121592118]0.0
28942Myles, Mr. Thomas Francismale62.0002402769.6875None...3.01.0(1.0)(0.0, 0.0)(0.0, 1.0, 0.0)(0.0, 0.0, 0.0, 1.0)[2.0, 0.0, 0.0, 3.0, 1.0, 2.0, 0.0][26.2911066277, 3.70889337228][0.876370220924, 0.123629779076]0.0
38953Wirz, Mr. Albertmale27.0003151548.6625None...1.01.0(1.0)(1.0, 0.0)(0.0, 1.0, 0.0)(0.0, 1.0, 0.0, 0.0)(3.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0)[25.9086506078, 4.09134939223][0.863621686926, 0.136378313074]0.0
48963Hirvonen, Mrs. Alexander (Helga E Lindqvist)female22.011310129812.2875None...1.01.0(0.0)(1.0, 0.0)(0.0, 1.0, 0.0)(0.0, 1.0, 0.0, 0.0)[3.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0][21.2731321988, 8.72686780122][0.709104406626, 0.290895593374]0.0
58973Svensson, Mr. Johan Cervinmale14.00075389.2250None...0.01.0(1.0)(1.0, 0.0)(0.0, 1.0, 0.0)(1.0, 0.0, 0.0, 0.0)(3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0)[21.79357844, 8.20642155996][0.726452614668, 0.273547385332]0.0
68983Connolly, Miss. Katefemale30.0003309727.6292None...1.00.0(0.0)(0.0, 0.0)(1.0, 0.0, 0.0)(0.0, 1.0, 0.0, 0.0)[3.0, 0.0, 0.0, 1.0, 0.0, 2.0, 1.0][6.58546403436, 23.4145359656][0.219515467812, 0.780484532188]1.0
78992Caldwell, Mr. Albert Francismale26.01124873829.0000None...1.02.0(1.0)(1.0, 0.0)(0.0, 0.0, 1.0)(0.0, 1.0, 0.0, 0.0)[2.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0][26.9219114219, 3.07808857809][0.897397047397, 0.102602952603]0.0
89003Abrahim, Mrs. Joseph (Sophie Halaut Easu)female18.00026577.2292None...1.00.0(0.0)(0.0, 1.0)(1.0, 0.0, 0.0)(0.0, 1.0, 0.0, 0.0)[3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0][5.95376028838, 24.0462397116][0.198458676279, 0.801541323721]1.0
99013Davies, Mr. John Samuelmale21.020A/4 4887124.1500None...1.02.0(1.0)(1.0, 0.0)(0.0, 0.0, 1.0)(0.0, 1.0, 0.0, 0.0)[3.0, 2.0, 0.0, 1.0, 2.0, 0.0, 0.0][27.683257006, 2.31674299405][0.922775233532, 0.0772247664683]0.0
109023Ilieff, Mr. YliomaleNaN003492207.8958None...1.00.0(1.0)(1.0, 0.0)(1.0, 0.0, 0.0)(0.0, 1.0, 0.0, 0.0)(3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0)[27.1796680333, 2.8203319667][0.905988934443, 0.0940110655565]0.0
119031Jones, Mr. Charles Cressonmale46.00069426.0000None...2.02.0(1.0)(1.0, 0.0)(0.0, 0.0, 1.0)(0.0, 0.0, 1.0, 0.0)(1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0)[14.2653782322, 15.7346217678][0.475512607739, 0.524487392261]1.0
129041Snyder, Mrs. John Pillsbury (Nelle Stevenson)female23.0102122882.2667B45...1.03.0(0.0)(1.0, 0.0)(0.0, 0.0, 0.0)(0.0, 1.0, 0.0, 0.0)[1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 1.0][0.295454545455, 29.7045454545][0.00984848484848, 0.990151515152]1.0
139052Howard, Mr. Benjaminmale63.0102406526.0000None...3.02.0(1.0)(1.0, 0.0)(0.0, 0.0, 1.0)(0.0, 0.0, 0.0, 1.0)[2.0, 1.0, 0.0, 3.0, 2.0, 0.0, 0.0][27.5151896549, 2.4848103451][0.917172988497, 0.0828270115032]0.0
149061Chaffee, Mrs. Herbert Fuller (Carrie Constance...female47.010W.E.P. 573461.1750E31...2.03.0(0.0)(1.0, 0.0)(0.0, 0.0, 0.0)(0.0, 0.0, 1.0, 0.0)[1.0, 1.0, 0.0, 2.0, 3.0, 0.0, 1.0][0.0, 30.0][0.0, 1.0]1.0
159072del Carlo, Mrs. Sebastiano (Argenia Genovesi)female24.010SC/PARIS 216727.7208None...1.02.0(0.0)(0.0, 1.0)(0.0, 0.0, 1.0)(0.0, 1.0, 0.0, 0.0)[2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 1.0][2.02801517479, 27.9719848252][0.0676005058263, 0.932399494174]1.0
169082Keane, Mr. Danielmale35.00023373412.3500None...2.01.0(1.0)(0.0, 0.0)(0.0, 1.0, 0.0)(0.0, 0.0, 1.0, 0.0)[2.0, 0.0, 0.0, 2.0, 1.0, 2.0, 0.0][26.5861882672, 3.41381173276][0.886206275575, 0.113793724425]0.0
179093Assaf, Mr. Geriosmale21.00026927.2250None...1.00.0(1.0)(0.0, 1.0)(1.0, 0.0, 0.0)(0.0, 1.0, 0.0, 0.0)(3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0)[25.8019873912, 4.19801260875][0.860066246375, 0.139933753625]0.0
189103Ilmakangas, Miss. Ida Livijafemale27.010STON/O2. 31012707.9250None...1.01.0(0.0)(1.0, 0.0)(0.0, 1.0, 0.0)(0.0, 1.0, 0.0, 0.0)[3.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0][20.8412096143, 9.15879038571][0.694706987143, 0.305293012857]0.0
199113\"Assaf Khalil, Mrs. Mariana (Miriam\"\")\"\"\"female45.00026967.2250None...2.00.0(0.0)(0.0, 1.0)(1.0, 0.0, 0.0)(0.0, 0.0, 1.0, 0.0)[3.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0][17.3386478658, 12.6613521342][0.577954928861, 0.422045071139]0.0
\n", "

20 rows × 24 columns

\n", "
" ], "text/plain": [ " PassengerId Pclass Name \\\n", "0 892 3 Kelly, Mr. James \n", "1 893 3 Wilkes, Mrs. James (Ellen Needs) \n", "2 894 2 Myles, Mr. Thomas Francis \n", "3 895 3 Wirz, Mr. Albert \n", "4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) \n", "5 897 3 Svensson, Mr. Johan Cervin \n", "6 898 3 Connolly, Miss. Kate \n", "7 899 2 Caldwell, Mr. Albert Francis \n", "8 900 3 Abrahim, Mrs. Joseph (Sophie Halaut Easu) \n", "9 901 3 Davies, Mr. John Samuel \n", "10 902 3 Ilieff, Mr. Ylio \n", "11 903 1 Jones, Mr. Charles Cresson \n", "12 904 1 Snyder, Mrs. John Pillsbury (Nelle Stevenson) \n", "13 905 2 Howard, Mr. Benjamin \n", "14 906 1 Chaffee, Mrs. Herbert Fuller (Carrie Constance... \n", "15 907 2 del Carlo, Mrs. Sebastiano (Argenia Genovesi) \n", "16 908 2 Keane, Mr. Daniel \n", "17 909 3 Assaf, Mr. Gerios \n", "18 910 3 Ilmakangas, Miss. Ida Livija \n", "19 911 3 \"Assaf Khalil, Mrs. Mariana (Miriam\"\")\"\"\" \n", "\n", " Sex Age SibSp Parch Ticket Fare Cabin ... \\\n", "0 male 34.5 0 0 330911 7.8292 None ... \n", "1 female 47.0 1 0 363272 7.0000 None ... \n", "2 male 62.0 0 0 240276 9.6875 None ... \n", "3 male 27.0 0 0 315154 8.6625 None ... \n", "4 female 22.0 1 1 3101298 12.2875 None ... \n", "5 male 14.0 0 0 7538 9.2250 None ... \n", "6 female 30.0 0 0 330972 7.6292 None ... \n", "7 male 26.0 1 1 248738 29.0000 None ... \n", "8 female 18.0 0 0 2657 7.2292 None ... \n", "9 male 21.0 2 0 A/4 48871 24.1500 None ... \n", "10 male NaN 0 0 349220 7.8958 None ... \n", "11 male 46.0 0 0 694 26.0000 None ... \n", "12 female 23.0 1 0 21228 82.2667 B45 ... \n", "13 male 63.0 1 0 24065 26.0000 None ... \n", "14 female 47.0 1 0 W.E.P. 5734 61.1750 E31 ... \n", "15 female 24.0 1 0 SC/PARIS 2167 27.7208 None ... \n", "16 male 35.0 0 0 233734 12.3500 None ... \n", "17 male 21.0 0 0 2692 7.2250 None ... \n", "18 female 27.0 1 0 STON/O2. 3101270 7.9250 None ... \n", "19 female 45.0 0 0 2696 7.2250 None ... 
\n", "\n", " bucketedAge bucketedFare oneHotSex oneHotEmbarked oneHotFare \\\n", "0 2.0 0.0 (1.0) (0.0, 0.0) (1.0, 0.0, 0.0) \n", "1 2.0 0.0 (0.0) (1.0, 0.0) (1.0, 0.0, 0.0) \n", "2 3.0 1.0 (1.0) (0.0, 0.0) (0.0, 1.0, 0.0) \n", "3 1.0 1.0 (1.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "4 1.0 1.0 (0.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "5 0.0 1.0 (1.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "6 1.0 0.0 (0.0) (0.0, 0.0) (1.0, 0.0, 0.0) \n", "7 1.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "8 1.0 0.0 (0.0) (0.0, 1.0) (1.0, 0.0, 0.0) \n", "9 1.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "10 1.0 0.0 (1.0) (1.0, 0.0) (1.0, 0.0, 0.0) \n", "11 2.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "12 1.0 3.0 (0.0) (1.0, 0.0) (0.0, 0.0, 0.0) \n", "13 3.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "14 2.0 3.0 (0.0) (1.0, 0.0) (0.0, 0.0, 0.0) \n", "15 1.0 2.0 (0.0) (0.0, 1.0) (0.0, 0.0, 1.0) \n", "16 2.0 1.0 (1.0) (0.0, 0.0) (0.0, 1.0, 0.0) \n", "17 1.0 0.0 (1.0) (0.0, 1.0) (1.0, 0.0, 0.0) \n", "18 1.0 1.0 (0.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "19 2.0 0.0 (0.0) (0.0, 1.0) (1.0, 0.0, 0.0) \n", "\n", " oneHotAge features \\\n", "0 (0.0, 0.0, 1.0, 0.0) (3.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0) \n", "1 (0.0, 0.0, 1.0, 0.0) [3.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0] \n", "2 (0.0, 0.0, 0.0, 1.0) [2.0, 0.0, 0.0, 3.0, 1.0, 2.0, 0.0] \n", "3 (0.0, 1.0, 0.0, 0.0) (3.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0) \n", "4 (0.0, 1.0, 0.0, 0.0) [3.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0] \n", "5 (1.0, 0.0, 0.0, 0.0) (3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0) \n", "6 (0.0, 1.0, 0.0, 0.0) [3.0, 0.0, 0.0, 1.0, 0.0, 2.0, 1.0] \n", "7 (0.0, 1.0, 0.0, 0.0) [2.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0] \n", "8 (0.0, 1.0, 0.0, 0.0) [3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0] \n", "9 (0.0, 1.0, 0.0, 0.0) [3.0, 2.0, 0.0, 1.0, 2.0, 0.0, 0.0] \n", "10 (0.0, 1.0, 0.0, 0.0) (3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0) \n", "11 (0.0, 0.0, 1.0, 0.0) (1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0) \n", "12 (0.0, 1.0, 0.0, 0.0) [1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 1.0] \n", "13 (0.0, 0.0, 0.0, 1.0) [2.0, 1.0, 0.0, 3.0, 
2.0, 0.0, 0.0] \n", "14 (0.0, 0.0, 1.0, 0.0) [1.0, 1.0, 0.0, 2.0, 3.0, 0.0, 1.0] \n", "15 (0.0, 1.0, 0.0, 0.0) [2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 1.0] \n", "16 (0.0, 0.0, 1.0, 0.0) [2.0, 0.0, 0.0, 2.0, 1.0, 2.0, 0.0] \n", "17 (0.0, 1.0, 0.0, 0.0) (3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0) \n", "18 (0.0, 1.0, 0.0, 0.0) [3.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0] \n", "19 (0.0, 0.0, 1.0, 0.0) [3.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0] \n", "\n", " rawPrediction probability \\\n", "0 [29.2960574888, 0.703942511236] [0.976535249625, 0.0234647503745] \n", "1 [17.9363522365, 12.0636477635] [0.597878407882, 0.402121592118] \n", "2 [26.2911066277, 3.70889337228] [0.876370220924, 0.123629779076] \n", "3 [25.9086506078, 4.09134939223] [0.863621686926, 0.136378313074] \n", "4 [21.2731321988, 8.72686780122] [0.709104406626, 0.290895593374] \n", "5 [21.79357844, 8.20642155996] [0.726452614668, 0.273547385332] \n", "6 [6.58546403436, 23.4145359656] [0.219515467812, 0.780484532188] \n", "7 [26.9219114219, 3.07808857809] [0.897397047397, 0.102602952603] \n", "8 [5.95376028838, 24.0462397116] [0.198458676279, 0.801541323721] \n", "9 [27.683257006, 2.31674299405] [0.922775233532, 0.0772247664683] \n", "10 [27.1796680333, 2.8203319667] [0.905988934443, 0.0940110655565] \n", "11 [14.2653782322, 15.7346217678] [0.475512607739, 0.524487392261] \n", "12 [0.295454545455, 29.7045454545] [0.00984848484848, 0.990151515152] \n", "13 [27.5151896549, 2.4848103451] [0.917172988497, 0.0828270115032] \n", "14 [0.0, 30.0] [0.0, 1.0] \n", "15 [2.02801517479, 27.9719848252] [0.0676005058263, 0.932399494174] \n", "16 [26.5861882672, 3.41381173276] [0.886206275575, 0.113793724425] \n", "17 [25.8019873912, 4.19801260875] [0.860066246375, 0.139933753625] \n", "18 [20.8412096143, 9.15879038571] [0.694706987143, 0.305293012857] \n", "19 [17.3386478658, 12.6613521342] [0.577954928861, 0.422045071139] \n", "\n", " prediction \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "5 0.0 \n", "6 1.0 \n", "7 0.0 \n", "8 1.0 
\n", "9 0.0 \n", "10 0.0 \n", "11 1.0 \n", "12 1.0 \n", "13 0.0 \n", "14 1.0 \n", "15 1.0 \n", "16 0.0 \n", "17 0.0 \n", "18 0.0 \n", "19 0.0 \n", "\n", "[20 rows x 24 columns]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test.limit(20).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Write the predictions to CSV file in Kaggle specified format." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.types import IntegerType\n", "\n", "csvPath = 'prediction.csv'\n", "test.select('PassengerId', 'prediction')\\\n", " .coalesce(1)\\\n", " .withColumn('Survived', test['prediction'].cast(IntegerType()))\\\n", " .drop('prediction')\\\n", " .write.csv(csvPath, header='true', mode='ignore')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }