{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from pyspark.sql import SparkSession\n", "import pandas as pd\n", "\n", "import sys\n", "sys.path.append('..')\n", "from utils.pysparkutils import *\n", "\n", "spark = SparkSession.builder.appName(\"titanic\").getOrCreate()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- PassengerId: integer (nullable = true)\n", " |-- Survived: integer (nullable = true)\n", " |-- Pclass: integer (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: double (nullable = true)\n", " |-- SibSp: integer (nullable = true)\n", " |-- Parch: integer (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: double (nullable = true)\n", " |-- Cabin: string (nullable = true)\n", " |-- Embarked: string (nullable = true)\n", "\n", "root\n", " |-- PassengerId: integer (nullable = true)\n", " |-- Pclass: integer (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: double (nullable = true)\n", " |-- SibSp: integer (nullable = true)\n", " |-- Parch: integer (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: double (nullable = true)\n", " |-- Cabin: string (nullable = true)\n", " |-- Embarked: string (nullable = true)\n", "\n" ] } ], "source": [ "train = spark.read.csv('./train.csv', header=\"true\", inferSchema=\"true\")\n", "test = spark.read.csv('./test.csv', header=\"true\", inferSchema=\"true\")\n", "\n", "train.printSchema()\n", "test.printSchema()\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " 
vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>PassengerId</th>\n", " <th>Survived</th>\n", " <th>Pclass</th>\n", " <th>Name</th>\n", " <th>Sex</th>\n", " <th>Age</th>\n", " <th>SibSp</th>\n", " <th>Parch</th>\n", " <th>Ticket</th>\n", " <th>Fare</th>\n", " <th>Cabin</th>\n", " <th>Embarked</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Braund, Mr. Owen Harris</td>\n", " <td>male</td>\n", " <td>22.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>A/5 21171</td>\n", " <td>7.2500</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>2</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n", " <td>female</td>\n", " <td>38.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>PC 17599</td>\n", " <td>71.2833</td>\n", " <td>C85</td>\n", " <td>C</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>3</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>Heikkinen, Miss. Laina</td>\n", " <td>female</td>\n", " <td>26.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>STON/O2. 3101282</td>\n", " <td>7.9250</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>4</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n", " <td>female</td>\n", " <td>35.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>113803</td>\n", " <td>53.1000</td>\n", " <td>C123</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>5</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Allen, Mr. 
William Henry</td>\n", " <td>male</td>\n", " <td>35.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>373450</td>\n", " <td>8.0500</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>6</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Moran, Mr. James</td>\n", " <td>male</td>\n", " <td>NaN</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>330877</td>\n", " <td>8.4583</td>\n", " <td>None</td>\n", " <td>Q</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>7</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>McCarthy, Mr. Timothy J</td>\n", " <td>male</td>\n", " <td>54.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>17463</td>\n", " <td>51.8625</td>\n", " <td>E46</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>8</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Palsson, Master. Gosta Leonard</td>\n", " <td>male</td>\n", " <td>2.0</td>\n", " <td>3</td>\n", " <td>1</td>\n", " <td>349909</td>\n", " <td>21.0750</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>9</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td>\n", " <td>female</td>\n", " <td>27.0</td>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>347742</td>\n", " <td>11.1333</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>10</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>Nasser, Mrs. Nicholas (Adele Achem)</td>\n", " <td>female</td>\n", " <td>14.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>237736</td>\n", " <td>30.0708</td>\n", " <td>None</td>\n", " <td>C</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>11</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>Sandstrom, Miss. 
Marguerite Rut</td>\n", " <td>female</td>\n", " <td>4.0</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>PP 9549</td>\n", " <td>16.7000</td>\n", " <td>G6</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>12</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>Bonnell, Miss. Elizabeth</td>\n", " <td>female</td>\n", " <td>58.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>113783</td>\n", " <td>26.5500</td>\n", " <td>C103</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>13</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Saundercock, Mr. William Henry</td>\n", " <td>male</td>\n", " <td>20.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>A/5. 2151</td>\n", " <td>8.0500</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>14</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Andersson, Mr. Anders Johan</td>\n", " <td>male</td>\n", " <td>39.0</td>\n", " <td>1</td>\n", " <td>5</td>\n", " <td>347082</td>\n", " <td>31.2750</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>15</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Vestrom, Miss. Hulda Amanda Adolfina</td>\n", " <td>female</td>\n", " <td>14.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>350406</td>\n", " <td>7.8542</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>16</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>Hewlett, Mrs. (Mary D Kingcome)</td>\n", " <td>female</td>\n", " <td>55.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>248706</td>\n", " <td>16.0000</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>17</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Rice, Master. 
Eugene</td>\n", " <td>male</td>\n", " <td>2.0</td>\n", " <td>4</td>\n", " <td>1</td>\n", " <td>382652</td>\n", " <td>29.1250</td>\n", " <td>None</td>\n", " <td>Q</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>18</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>Williams, Mr. Charles Eugene</td>\n", " <td>male</td>\n", " <td>NaN</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>244373</td>\n", " <td>13.0000</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>19</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>Vander Planke, Mrs. Julius (Emelia Maria Vande...</td>\n", " <td>female</td>\n", " <td>31.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>345763</td>\n", " <td>18.0000</td>\n", " <td>None</td>\n", " <td>S</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>20</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>Masselmani, Mrs. Fatima</td>\n", " <td>female</td>\n", " <td>NaN</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>2649</td>\n", " <td>7.2250</td>\n", " <td>None</td>\n", " <td>C</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", "5 6 0 3 \n", "6 7 0 1 \n", "7 8 0 3 \n", "8 9 1 3 \n", "9 10 1 2 \n", "10 11 1 3 \n", "11 12 1 1 \n", "12 13 0 3 \n", "13 14 0 3 \n", "14 15 0 3 \n", "15 16 1 2 \n", "16 17 0 3 \n", "17 18 1 2 \n", "18 19 0 3 \n", "19 20 1 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "5 Moran, Mr. James male NaN 0 \n", "6 McCarthy, Mr. Timothy J male 54.0 0 \n", "7 Palsson, Master. Gosta Leonard male 2.0 3 \n", "8 Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) female 27.0 0 \n", "9 Nasser, Mrs. 
Nicholas (Adele Achem) female 14.0 1 \n", "10 Sandstrom, Miss. Marguerite Rut female 4.0 1 \n", "11 Bonnell, Miss. Elizabeth female 58.0 0 \n", "12 Saundercock, Mr. William Henry male 20.0 0 \n", "13 Andersson, Mr. Anders Johan male 39.0 1 \n", "14 Vestrom, Miss. Hulda Amanda Adolfina female 14.0 0 \n", "15 Hewlett, Mrs. (Mary D Kingcome) female 55.0 0 \n", "16 Rice, Master. Eugene male 2.0 4 \n", "17 Williams, Mr. Charles Eugene male NaN 0 \n", "18 Vander Planke, Mrs. Julius (Emelia Maria Vande... female 31.0 1 \n", "19 Masselmani, Mrs. Fatima female NaN 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 None S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 None S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 None S \n", "5 0 330877 8.4583 None Q \n", "6 0 17463 51.8625 E46 S \n", "7 1 349909 21.0750 None S \n", "8 2 347742 11.1333 None S \n", "9 0 237736 30.0708 None C \n", "10 1 PP 9549 16.7000 G6 S \n", "11 0 113783 26.5500 C103 S \n", "12 0 A/5. 2151 8.0500 None S \n", "13 5 347082 31.2750 None S \n", "14 0 350406 7.8542 None S \n", "15 0 248706 16.0000 None S \n", "16 1 382652 29.1250 None Q \n", "17 0 244373 13.0000 None S \n", "18 0 345763 18.0000 None S \n", "19 0 2649 7.2250 None C " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.limit(20).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section we will explore missing data." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Age', 0.19865319865319866),\n", " ('Cabin', 0.7710437710437711),\n", " ('Embarked', 0.002244668911335578)]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "findMissingValuesCols(train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see almost 80% of Cabin column is missing data. 
So we will drop the Cabin column.\n", "Very few values are missing in the Embarked column, so we will simply drop those rows.\n", "About 20% of the Age values are missing; these will be imputed with the median." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import Imputer\n", "ageImputer = Imputer(inputCols=['Age'], outputCols=['imputedAge'], strategy='median')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- PassengerId: integer (nullable = true)\n", " |-- Survived: integer (nullable = true)\n", " |-- Pclass: integer (nullable = true)\n", " |-- Name: string (nullable = true)\n", " |-- Sex: string (nullable = true)\n", " |-- Age: double (nullable = true)\n", " |-- SibSp: integer (nullable = true)\n", " |-- Parch: integer (nullable = true)\n", " |-- Ticket: string (nullable = true)\n", " |-- Fare: double (nullable = true)\n", " |-- Embarked: string (nullable = true)\n", "\n" ] }, { "data": { "text/plain": [ "889" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train = train.filter(train.Embarked.isNotNull())\n", "train = train.drop('Cabin')\n", "train.printSchema()\n", "train.count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exploratory Data Analysis\n", "In the next few sections, we will explore the training data and the relationship between the different features and the label.\n", "As we already know, most of the passengers on the Titanic didn't survive. Our training data suggests the same: only 340 of the 889 passengers (a little over one-third) survived. The same holds for the known effects of passenger class and sex on survival, as the tables below show."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Survived</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>340</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>549</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Survived count\n", "0 1 340\n", "1 0 549" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labelCol = 'Survived'\n", "train.groupby(labelCol).count().toPandas()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Survived_Sex</th>\n", " <th>female</th>\n", " <th>male</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>231</td>\n", " <td>109</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>81</td>\n", " <td>468</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Survived_Sex female male\n", "0 1 231 109\n", "1 0 81 468" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [
"train.crosstab(labelCol, 'Sex').toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pointwise Mutual Information (PMI) is a useful metric for exploring the relationship between two categorical features. PMI gives a scalar value for each pair of values taken by the two features, $\\mathrm{pmi}(x; y) = \\log \\frac{p(x, y)}{p(x) p(y)}$. The value measures how much information observing one value provides about the other, i.e. how much more (or less) often the two values occur together than they would if the features were independent.\n", "\n", "PMI can be normalized to the range [-1,+1] by dividing it by $-\\log p(x, y)$. The result is called Normalized PMI and gives:\n", "* -1 (in the limit) for never occurring together,\n", "* 0 for independence,\n", "* +1 for complete co-occurrence\n", "\n", "\n", "The `calcNormalizedPointwiseMutualInformation` function is implemented in the `pysparkutils.py` file in the `utils` directory." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Sex</th>\n", " <th>Survived</th>\n", " <th>Normalized PMI</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>2</th>\n", " <td>female</td>\n", " <td>0</td>\n", " <td>-0.361721</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>female</td>\n", " <td>1</td>\n", " <td>0.490151</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>male</td>\n", " <td>0</td>\n", " <td>0.424895</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>male</td>\n", " <td>1</td>\n", " <td>-0.336078</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Sex Survived Normalized PMI\n", "2 female 0 -0.361721\n", "1 female 1 0.490151\n", "0 male 0 0.424895\n", "3 male 1 -0.336078" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [
"pmis = calcNormalizedPointwiseMutualInformation(train, 'Sex', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Sex', labelCol)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Survived_Pclass</th>\n", " <th>1</th>\n", " <th>2</th>\n", " <th>3</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>134</td>\n", " <td>87</td>\n", " <td>119</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>80</td>\n", " <td>97</td>\n", " <td>372</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Survived_Pclass 1 2 3\n", "0 1 134 87 119\n", "1 0 80 97 372" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Pclass').toPandas()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Pclass</th>\n", " <th>Survived</th>\n", " <th>Normalized PMI</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>-0.208445</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>1</td>\n", " <td>1</td>\n", " 
<td>0.260544</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>2</td>\n", " <td>0</td>\n", " <td>-0.071421</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>2</td>\n", " <td>1</td>\n", " <td>0.091268</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>0.234674</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>3</td>\n", " <td>1</td>\n", " <td>-0.226840</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Pclass Survived Normalized PMI\n", "0 1 0 -0.208445\n", "2 1 1 0.260544\n", "4 2 0 -0.071421\n", "3 2 1 0.091268\n", "5 3 0 0.234674\n", "1 3 1 -0.226840" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Pclass', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Pclass', labelCol)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Survived_Embarked</th>\n", " <th>C</th>\n", " <th>Q</th>\n", " <th>S</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>93</td>\n", " <td>30</td>\n", " <td>217</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>75</td>\n", " <td>47</td>\n", " <td>427</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Survived_Embarked C Q S\n", "0 1 93 30 217\n", "1 0 75 47 427" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Embarked').toPandas()" ] }, { "cell_type": 
"code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Embarked</th>\n", " <th>Survived</th>\n", " <th>Normalized PMI</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>5</th>\n", " <td>C</td>\n", " <td>0</td>\n", " <td>-0.131229</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>C</td>\n", " <td>1</td>\n", " <td>0.163804</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Q</td>\n", " <td>0</td>\n", " <td>-0.003966</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>Q</td>\n", " <td>1</td>\n", " <td>0.005472</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>S</td>\n", " <td>0</td>\n", " <td>0.096935</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>S</td>\n", " <td>1</td>\n", " <td>-0.089810</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Embarked Survived Normalized PMI\n", "5 C 0 -0.131229\n", "3 C 1 0.163804\n", "4 Q 0 -0.003966\n", "0 Q 1 0.005472\n", "1 S 0 0.096935\n", "2 S 1 -0.089810" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Embarked', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Embarked', labelCol)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", 
"</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Survived_SibSp</th>\n", " <th>0</th>\n", " <th>1</th>\n", " <th>2</th>\n", " <th>3</th>\n", " <th>4</th>\n", " <th>5</th>\n", " <th>8</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>208</td>\n", " <td>112</td>\n", " <td>13</td>\n", " <td>4</td>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>398</td>\n", " <td>97</td>\n", " <td>15</td>\n", " <td>12</td>\n", " <td>15</td>\n", " <td>5</td>\n", " <td>7</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Survived_SibSp 0 1 2 3 4 5 8\n", "0 1 208 112 13 4 3 0 0\n", "1 0 398 97 15 12 15 5 7" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'SibSp').toPandas()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>SibSp</th>\n", " <th>Survived</th>\n", " <th>Normalized PMI</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>4</th>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>0.076614</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>-0.074483</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>-0.128928</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>0.162829</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " 
<td>2</td>\n", " <td>0</td>\n", " <td>-0.034825</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>2</td>\n", " <td>1</td>\n", " <td>0.045891</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>0.045135</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>3</td>\n", " <td>1</td>\n", " <td>-0.078675</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>4</td>\n", " <td>0</td>\n", " <td>0.073413</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>4</td>\n", " <td>1</td>\n", " <td>-0.145939</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>5</td>\n", " <td>0</td>\n", " <td>0.093038</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>8</td>\n", " <td>0</td>\n", " <td>0.099500</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " SibSp Survived Normalized PMI\n", "4 0 0 0.076614\n", "6 0 1 -0.074483\n", "0 1 0 -0.128928\n", "3 1 1 0.162829\n", "7 2 0 -0.034825\n", "5 2 1 0.045891\n", "10 3 0 0.045135\n", "1 3 1 -0.078675\n", "2 4 0 0.073413\n", "11 4 1 -0.145939\n", "8 5 0 0.093038\n", "9 8 0 0.099500" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'SibSp', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'SibSp', labelCol)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Survived_Parch</th>\n", " <th>0</th>\n", " <th>1</th>\n", " <th>2</th>\n", " <th>3</th>\n", " <th>4</th>\n", " <th>5</th>\n", " <th>6</th>\n", " </tr>\n", " </thead>\n", " 
<tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>231</td>\n", " <td>65</td>\n", " <td>40</td>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>445</td>\n", " <td>53</td>\n", " <td>40</td>\n", " <td>2</td>\n", " <td>4</td>\n", " <td>4</td>\n", " <td>1</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Survived_Parch 0 1 2 3 4 5 6\n", "0 1 231 65 40 3 0 1 0\n", "1 0 445 53 40 2 4 4 1" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.crosstab(labelCol, 'Parch').toPandas()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Parch</th>\n", " <th>Survived</th>\n", " <th>Normalized PMI</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>5</th>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>0.092309</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>-0.083569</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>-0.112913</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>0.139486</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>2</td>\n", " <td>0</td>\n", " <td>-0.068086</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>2</td>\n", " <td>1</td>\n", " <td>0.086419</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>-0.071231</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>3</td>\n", " 
<td>1</td>\n", " <td>0.079123</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>4</td>\n", " <td>0</td>\n", " <td>0.089196</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>5</td>\n", " <td>0</td>\n", " <td>0.047902</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>5</td>\n", " <td>1</td>\n", " <td>-0.095475</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>6</td>\n", " <td>0</td>\n", " <td>0.070986</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Parch Survived Normalized PMI\n", "5 0 0 0.092309\n", "7 0 1 -0.083569\n", "0 1 0 -0.112913\n", "4 1 1 0.139486\n", "8 2 0 -0.068086\n", "6 2 1 0.086419\n", "11 3 0 -0.071231\n", "1 3 1 0.079123\n", "3 4 0 0.089196\n", "9 5 0 0.047902\n", "10 5 1 -0.095475\n", "2 6 0 0.070986" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmis = calcNormalizedPointwiseMutualInformation(train, 'Parch', labelCol)\n", "toPandasDF(pmis, 'Normalized PMI', 'Parch', labelCol)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will calculate the normalized entropy of each categorical feature (its entropy divided by the log of the number of distinct values, so it lies between 0 and 1). For a categorical feature, entropy plays a role similar to that of variance for a numeric feature: the higher it is, the more evenly the values are spread across the categories."
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Feature</th>\n", " <th>Entropy</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Sex</td>\n", " <td>0.934919</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Pclass</td>\n", " <td>0.907245</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Embarked</td>\n", " <td>0.692048</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>SibSp</td>\n", " <td>0.477435</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Parch</td>\n", " <td>0.402510</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Feature Entropy\n", "0 Sex 0.934919\n", "1 Pclass 0.907245\n", "2 Embarked 0.692048\n", "3 SibSp 0.477435\n", "4 Parch 0.402510" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "columns = ['Sex', 'Pclass', 'Embarked', 'SibSp', 'Parch']\n", "entropies = calcNormalizedEntropy(train, *columns)\n", "dictToPandasDF(entropies, 'Feature', 'Entropy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we test whether each categorical feature is independent of the `Survived` label using Pearson's chi-square test."
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pValues: [0.0,0.0,6.02813466444e-06,4.02603175464e-07,4.93843854699e-11,0.0101897422598,0.0315461645121,4.25298058386e-05,0.00150622342036,0.612884928604,0.116808580457]\n", "degreesOfFreedom: [2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", "statistics: [100.980407261,260.756342249,20.4792462347,25.6818141585,43.2014376768,6.60142083307,4.62299205997,16.755010532,10.0709867693,0.255995235761,2.45959925254]\n" ] } ], "source": [ "from pyspark.ml import Pipeline\n", "from pyspark.ml.stat import ChiSquareTest\n", "from pyspark.ml.feature import Bucketizer, OneHotEncoderEstimator, StringIndexer, VectorAssembler, VectorIndexer\n", "\n", "# EDA-only stages carry an `eda` prefix to keep them apart from the classification stages defined further down.\n", "edaEmbarkedIndexer = StringIndexer(inputCol='Embarked', outputCol='indexedEmbarked')\n", "edaSexIndexer = StringIndexer(inputCol='Sex', outputCol='indexedSex')\n", "\n", "edaAgeImputer = Imputer(inputCols=['Age'], outputCols=['imputedAge'], strategy='median')\n", "\n", "ageSplits = [0, 16, 32, 48, 64, 200]\n", "edaAgeBucketizer = Bucketizer(splits=ageSplits, inputCol='imputedAge', outputCol='bucketedAge')\n", "\n", "fareSplits = [-float('inf'), 7.91, 14.454, 31, float('inf')]\n", "edaFareBucketizer = Bucketizer(splits=fareSplits, inputCol='Fare', outputCol='bucketedFare')\n", "\n", "edaOneHotEncoder = OneHotEncoderEstimator(inputCols=['indexedSex', 'indexedEmbarked', 'bucketedFare', 'bucketedAge'],\n", "                                          outputCols=['oneHotSex', 'oneHotEmbarked', 'oneHotFare', 'oneHotAge'])\n", "inputCols = ['Pclass', 'oneHotSex', 'oneHotEmbarked', 'oneHotFare', 'oneHotAge']\n", "edaAssembler = VectorAssembler(inputCols=inputCols, outputCol='features')\n", "\n", "edaPipeline = Pipeline(stages=[edaEmbarkedIndexer, edaSexIndexer, edaAgeImputer, edaAgeBucketizer,\n", "                               edaFareBucketizer, edaOneHotEncoder, edaAssembler])\n", "# The result goes to a new frame, so `train` itself keeps only its original columns.\n", "chiSqTrain = edaPipeline.fit(train).transform(train)\n", "\n", "# One p-value / statistic per slot of the assembled `features` vector, each tested against `Survived`.\n", "chiSqResult = ChiSquareTest.test(chiSqTrain, 'features', 'Survived').head()\n", "print(\"pValues: \" + str(chiSqResult.pValues))\n", "print(\"degreesOfFreedom: \" + str(chiSqResult.degreesOfFreedom))\n", "print(\"statistics: \" + str(chiSqResult.statistics))" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Classification" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import StringIndexer\n", "\n", "# NOTE(review): handleInvalid='skip' drops rows the stage cannot encode instead of scoring them (the bucketizers\n", "# in the next cell use the same setting) - confirm the test set cannot lose passengers this way.\n", "embarkedIndexer = StringIndexer(inputCol='Embarked', outputCol='indexedEmbarked', handleInvalid='skip')\n", "sexFeatureIndexer = StringIndexer(inputCol='Sex', outputCol='indexedSex', handleInvalid='skip')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import Bucketizer\n", "\n", "ageSplits = [0, 16, 32, 48, 64, 200]\n", "ageBucketizer = Bucketizer(splits=ageSplits, inputCol='imputedAge', outputCol='bucketedAge', handleInvalid='skip')\n", "fareSplits = [-float('inf'), 7.91, 14.454, 31, float('inf')]\n", "fareBucketizer = Bucketizer(splits=fareSplits, inputCol='Fare', outputCol='bucketedFare', handleInvalid='skip')" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.feature import OneHotEncoderEstimator, VectorIndexer\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.ml.classification import RandomForestClassifier\n", "\n", "oneHotEncoderEstimator = OneHotEncoderEstimator(inputCols=['indexedSex', 'indexedEmbarked', 'bucketedFare', 'bucketedAge'],\n", "                                                outputCols=['oneHotSex', 'oneHotEmbarked', 'oneHotFare', 'oneHotAge'])\n", "# NOTE: `features` is assembled from the indexed/bucketed columns; the one-hot outputs above are\n", "# produced by the pipeline but are not fed to the forest.\n", "assembler = VectorAssembler(inputCols=['Pclass', 'SibSp', 'Parch', 'bucketedAge',\n", "                                       'bucketedFare', 'indexedEmbarked', 'indexedSex'], outputCol='features')\n", "rf = RandomForestClassifier(labelCol=labelCol, featuresCol='features')" ] }, 
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.tuning import CrossValidator, ParamGridBuilder\n", "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", "from pyspark.ml import Pipeline\n", "\n", "pipeline = Pipeline(stages=[ageImputer, embarkedIndexer, sexFeatureIndexer, ageBucketizer,\n", "                            fareBucketizer, oneHotEncoderEstimator, assembler, rf])\n", "\n", "# 4 x 2 = 8 parameter combinations, each scored by 10-fold cross-validation on area under ROC.\n", "grid = ParamGridBuilder().addGrid(rf.numTrees, [15, 20, 25, 30])\\\n", "                         .addGrid(rf.maxDepth, [5, 8])\\\n", "                         .build()\n", "\n", "cv = CrossValidator(estimator=pipeline,\n", "                    estimatorParamMaps=grid,\n", "                    evaluator=BinaryClassificationEvaluator(labelCol=labelCol, metricName='areaUnderROC'),\n", "                    numFolds=10)\n", "\n", "model = cv.fit(train)\n", "# Score into a new frame instead of rebinding `train`: the pipeline refuses to add output columns that\n", "# already exist, so overwriting `train` made this cell and the EDA cell above fail when re-run.\n", "trainPredictions = model.transform(train)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9265509482481523" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluator = model.getEvaluator()\n", "# Area under ROC on the same data the model was fitted on, so it is optimistic;\n", "# the cross-validated score of each grid point is available in model.avgMetrics.\n", "evaluator.evaluate(trainPredictions)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# `test` is rebound to the scored frame because the cells below read its `prediction` column;\n", "# reload test.csv before running this cell a second time.\n", "test = model.transform(test)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>PassengerId</th>\n", " <th>Pclass</th>\n", " <th>Name</th>\n", " <th>Sex</th>\n", " <th>Age</th>\n", " <th>SibSp</th>\n", " <th>Parch</th>\n", " <th>Ticket</th>\n", " <th>Fare</th>\n", " <th>Cabin</th>\n", " <th>...</th>\n", " <th>bucketedAge</th>\n", " <th>bucketedFare</th>\n", " <th>oneHotSex</th>\n", " <th>oneHotEmbarked</th>\n", " <th>oneHotFare</th>\n", " <th>oneHotAge</th>\n", " 
<th>features</th>\n", " <th>rawPrediction</th>\n", " <th>probability</th>\n", " <th>prediction</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>892</td>\n", " <td>3</td>\n", " <td>Kelly, Mr. James</td>\n", " <td>male</td>\n", " <td>34.5</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>330911</td>\n", " <td>7.8292</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>2.0</td>\n", " <td>0.0</td>\n", " <td>(1.0)</td>\n", " <td>(0.0, 0.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0, 0.0)</td>\n", " <td>(3.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0)</td>\n", " <td>[29.2960574888, 0.703942511236]</td>\n", " <td>[0.976535249625, 0.0234647503745]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>893</td>\n", " <td>3</td>\n", " <td>Wilkes, Mrs. James (Ellen Needs)</td>\n", " <td>female</td>\n", " <td>47.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>363272</td>\n", " <td>7.0000</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>2.0</td>\n", " <td>0.0</td>\n", " <td>(0.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0, 0.0)</td>\n", " <td>[3.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0]</td>\n", " <td>[17.9363522365, 12.0636477635]</td>\n", " <td>[0.597878407882, 0.402121592118]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>894</td>\n", " <td>2</td>\n", " <td>Myles, Mr. Thomas Francis</td>\n", " <td>male</td>\n", " <td>62.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>240276</td>\n", " <td>9.6875</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>3.0</td>\n", " <td>1.0</td>\n", " <td>(1.0)</td>\n", " <td>(0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 0.0, 1.0)</td>\n", " <td>[2.0, 0.0, 0.0, 3.0, 1.0, 2.0, 0.0]</td>\n", " <td>[26.2911066277, 3.70889337228]</td>\n", " <td>[0.876370220924, 0.123629779076]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>895</td>\n", " <td>3</td>\n", " <td>Wirz, Mr. 
Albert</td>\n", " <td>male</td>\n", " <td>27.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>315154</td>\n", " <td>8.6625</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>(3.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0)</td>\n", " <td>[25.9086506078, 4.09134939223]</td>\n", " <td>[0.863621686926, 0.136378313074]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>896</td>\n", " <td>3</td>\n", " <td>Hirvonen, Mrs. Alexander (Helga E Lindqvist)</td>\n", " <td>female</td>\n", " <td>22.0</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>3101298</td>\n", " <td>12.2875</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>(0.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[3.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0]</td>\n", " <td>[21.2731321988, 8.72686780122]</td>\n", " <td>[0.709104406626, 0.290895593374]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>897</td>\n", " <td>3</td>\n", " <td>Svensson, Mr. Johan Cervin</td>\n", " <td>male</td>\n", " <td>14.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>7538</td>\n", " <td>9.2250</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0)</td>\n", " <td>(1.0, 0.0, 0.0, 0.0)</td>\n", " <td>(3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[21.79357844, 8.20642155996]</td>\n", " <td>[0.726452614668, 0.273547385332]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>898</td>\n", " <td>3</td>\n", " <td>Connolly, Miss. 
Kate</td>\n", " <td>female</td>\n", " <td>30.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>330972</td>\n", " <td>7.6292</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>(0.0)</td>\n", " <td>(0.0, 0.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[3.0, 0.0, 0.0, 1.0, 0.0, 2.0, 1.0]</td>\n", " <td>[6.58546403436, 23.4145359656]</td>\n", " <td>[0.219515467812, 0.780484532188]</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>899</td>\n", " <td>2</td>\n", " <td>Caldwell, Mr. Albert Francis</td>\n", " <td>male</td>\n", " <td>26.0</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>248738</td>\n", " <td>29.0000</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>2.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[2.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0]</td>\n", " <td>[26.9219114219, 3.07808857809]</td>\n", " <td>[0.897397047397, 0.102602952603]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>900</td>\n", " <td>3</td>\n", " <td>Abrahim, Mrs. Joseph (Sophie Halaut Easu)</td>\n", " <td>female</td>\n", " <td>18.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>2657</td>\n", " <td>7.2292</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>(0.0)</td>\n", " <td>(0.0, 1.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]</td>\n", " <td>[5.95376028838, 24.0462397116]</td>\n", " <td>[0.198458676279, 0.801541323721]</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>901</td>\n", " <td>3</td>\n", " <td>Davies, Mr. 
John Samuel</td>\n", " <td>male</td>\n", " <td>21.0</td>\n", " <td>2</td>\n", " <td>0</td>\n", " <td>A/4 48871</td>\n", " <td>24.1500</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>2.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[3.0, 2.0, 0.0, 1.0, 2.0, 0.0, 0.0]</td>\n", " <td>[27.683257006, 2.31674299405]</td>\n", " <td>[0.922775233532, 0.0772247664683]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>902</td>\n", " <td>3</td>\n", " <td>Ilieff, Mr. Ylio</td>\n", " <td>male</td>\n", " <td>NaN</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>349220</td>\n", " <td>7.8958</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>(3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0)</td>\n", " <td>[27.1796680333, 2.8203319667]</td>\n", " <td>[0.905988934443, 0.0940110655565]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>903</td>\n", " <td>1</td>\n", " <td>Jones, Mr. Charles Cresson</td>\n", " <td>male</td>\n", " <td>46.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>694</td>\n", " <td>26.0000</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>2.0</td>\n", " <td>2.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0)</td>\n", " <td>(0.0, 0.0, 1.0, 0.0)</td>\n", " <td>(1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0)</td>\n", " <td>[14.2653782322, 15.7346217678]</td>\n", " <td>[0.475512607739, 0.524487392261]</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>904</td>\n", " <td>1</td>\n", " <td>Snyder, Mrs. 
John Pillsbury (Nelle Stevenson)</td>\n", " <td>female</td>\n", " <td>23.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>21228</td>\n", " <td>82.2667</td>\n", " <td>B45</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>(0.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 1.0]</td>\n", " <td>[0.295454545455, 29.7045454545]</td>\n", " <td>[0.00984848484848, 0.990151515152]</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>905</td>\n", " <td>2</td>\n", " <td>Howard, Mr. Benjamin</td>\n", " <td>male</td>\n", " <td>63.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>24065</td>\n", " <td>26.0000</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>3.0</td>\n", " <td>2.0</td>\n", " <td>(1.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0)</td>\n", " <td>(0.0, 0.0, 0.0, 1.0)</td>\n", " <td>[2.0, 1.0, 0.0, 3.0, 2.0, 0.0, 0.0]</td>\n", " <td>[27.5151896549, 2.4848103451]</td>\n", " <td>[0.917172988497, 0.0828270115032]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>906</td>\n", " <td>1</td>\n", " <td>Chaffee, Mrs. Herbert Fuller (Carrie Constance...</td>\n", " <td>female</td>\n", " <td>47.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>W.E.P. 5734</td>\n", " <td>61.1750</td>\n", " <td>E31</td>\n", " <td>...</td>\n", " <td>2.0</td>\n", " <td>3.0</td>\n", " <td>(0.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0, 0.0)</td>\n", " <td>[1.0, 1.0, 0.0, 2.0, 3.0, 0.0, 1.0]</td>\n", " <td>[0.0, 30.0]</td>\n", " <td>[0.0, 1.0]</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>907</td>\n", " <td>2</td>\n", " <td>del Carlo, Mrs. 
Sebastiano (Argenia Genovesi)</td>\n", " <td>female</td>\n", " <td>24.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>SC/PARIS 2167</td>\n", " <td>27.7208</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>2.0</td>\n", " <td>(0.0)</td>\n", " <td>(0.0, 1.0)</td>\n", " <td>(0.0, 0.0, 1.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 1.0]</td>\n", " <td>[2.02801517479, 27.9719848252]</td>\n", " <td>[0.0676005058263, 0.932399494174]</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>908</td>\n", " <td>2</td>\n", " <td>Keane, Mr. Daniel</td>\n", " <td>male</td>\n", " <td>35.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>233734</td>\n", " <td>12.3500</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>2.0</td>\n", " <td>1.0</td>\n", " <td>(1.0)</td>\n", " <td>(0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0, 0.0)</td>\n", " <td>[2.0, 0.0, 0.0, 2.0, 1.0, 2.0, 0.0]</td>\n", " <td>[26.5861882672, 3.41381173276]</td>\n", " <td>[0.886206275575, 0.113793724425]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>909</td>\n", " <td>3</td>\n", " <td>Assaf, Mr. Gerios</td>\n", " <td>male</td>\n", " <td>21.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>2692</td>\n", " <td>7.2250</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>(1.0)</td>\n", " <td>(0.0, 1.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>(3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0)</td>\n", " <td>[25.8019873912, 4.19801260875]</td>\n", " <td>[0.860066246375, 0.139933753625]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>18</th>\n", " <td>910</td>\n", " <td>3</td>\n", " <td>Ilmakangas, Miss. Ida Livija</td>\n", " <td>female</td>\n", " <td>27.0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>STON/O2. 
3101270</td>\n", " <td>7.9250</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>(0.0)</td>\n", " <td>(1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0)</td>\n", " <td>(0.0, 1.0, 0.0, 0.0)</td>\n", " <td>[3.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]</td>\n", " <td>[20.8412096143, 9.15879038571]</td>\n", " <td>[0.694706987143, 0.305293012857]</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>911</td>\n", " <td>3</td>\n", " <td>\"Assaf Khalil, Mrs. Mariana (Miriam\"\")\"\"\"</td>\n", " <td>female</td>\n", " <td>45.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>2696</td>\n", " <td>7.2250</td>\n", " <td>None</td>\n", " <td>...</td>\n", " <td>2.0</td>\n", " <td>0.0</td>\n", " <td>(0.0)</td>\n", " <td>(0.0, 1.0)</td>\n", " <td>(1.0, 0.0, 0.0)</td>\n", " <td>(0.0, 0.0, 1.0, 0.0)</td>\n", " <td>[3.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0]</td>\n", " <td>[17.3386478658, 12.6613521342]</td>\n", " <td>[0.577954928861, 0.422045071139]</td>\n", " <td>0.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>20 rows × 24 columns</p>\n", "</div>" ], "text/plain": [ " PassengerId Pclass Name \\\n", "0 892 3 Kelly, Mr. James \n", "1 893 3 Wilkes, Mrs. James (Ellen Needs) \n", "2 894 2 Myles, Mr. Thomas Francis \n", "3 895 3 Wirz, Mr. Albert \n", "4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) \n", "5 897 3 Svensson, Mr. Johan Cervin \n", "6 898 3 Connolly, Miss. Kate \n", "7 899 2 Caldwell, Mr. Albert Francis \n", "8 900 3 Abrahim, Mrs. Joseph (Sophie Halaut Easu) \n", "9 901 3 Davies, Mr. John Samuel \n", "10 902 3 Ilieff, Mr. Ylio \n", "11 903 1 Jones, Mr. Charles Cresson \n", "12 904 1 Snyder, Mrs. John Pillsbury (Nelle Stevenson) \n", "13 905 2 Howard, Mr. Benjamin \n", "14 906 1 Chaffee, Mrs. Herbert Fuller (Carrie Constance... \n", "15 907 2 del Carlo, Mrs. Sebastiano (Argenia Genovesi) \n", "16 908 2 Keane, Mr. Daniel \n", "17 909 3 Assaf, Mr. Gerios \n", "18 910 3 Ilmakangas, Miss. 
Ida Livija \n", "19 911 3 \"Assaf Khalil, Mrs. Mariana (Miriam\"\")\"\"\" \n", "\n", " Sex Age SibSp Parch Ticket Fare Cabin ... \\\n", "0 male 34.5 0 0 330911 7.8292 None ... \n", "1 female 47.0 1 0 363272 7.0000 None ... \n", "2 male 62.0 0 0 240276 9.6875 None ... \n", "3 male 27.0 0 0 315154 8.6625 None ... \n", "4 female 22.0 1 1 3101298 12.2875 None ... \n", "5 male 14.0 0 0 7538 9.2250 None ... \n", "6 female 30.0 0 0 330972 7.6292 None ... \n", "7 male 26.0 1 1 248738 29.0000 None ... \n", "8 female 18.0 0 0 2657 7.2292 None ... \n", "9 male 21.0 2 0 A/4 48871 24.1500 None ... \n", "10 male NaN 0 0 349220 7.8958 None ... \n", "11 male 46.0 0 0 694 26.0000 None ... \n", "12 female 23.0 1 0 21228 82.2667 B45 ... \n", "13 male 63.0 1 0 24065 26.0000 None ... \n", "14 female 47.0 1 0 W.E.P. 5734 61.1750 E31 ... \n", "15 female 24.0 1 0 SC/PARIS 2167 27.7208 None ... \n", "16 male 35.0 0 0 233734 12.3500 None ... \n", "17 male 21.0 0 0 2692 7.2250 None ... \n", "18 female 27.0 1 0 STON/O2. 3101270 7.9250 None ... \n", "19 female 45.0 0 0 2696 7.2250 None ... 
\n", "\n", " bucketedAge bucketedFare oneHotSex oneHotEmbarked oneHotFare \\\n", "0 2.0 0.0 (1.0) (0.0, 0.0) (1.0, 0.0, 0.0) \n", "1 2.0 0.0 (0.0) (1.0, 0.0) (1.0, 0.0, 0.0) \n", "2 3.0 1.0 (1.0) (0.0, 0.0) (0.0, 1.0, 0.0) \n", "3 1.0 1.0 (1.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "4 1.0 1.0 (0.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "5 0.0 1.0 (1.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "6 1.0 0.0 (0.0) (0.0, 0.0) (1.0, 0.0, 0.0) \n", "7 1.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "8 1.0 0.0 (0.0) (0.0, 1.0) (1.0, 0.0, 0.0) \n", "9 1.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "10 1.0 0.0 (1.0) (1.0, 0.0) (1.0, 0.0, 0.0) \n", "11 2.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "12 1.0 3.0 (0.0) (1.0, 0.0) (0.0, 0.0, 0.0) \n", "13 3.0 2.0 (1.0) (1.0, 0.0) (0.0, 0.0, 1.0) \n", "14 2.0 3.0 (0.0) (1.0, 0.0) (0.0, 0.0, 0.0) \n", "15 1.0 2.0 (0.0) (0.0, 1.0) (0.0, 0.0, 1.0) \n", "16 2.0 1.0 (1.0) (0.0, 0.0) (0.0, 1.0, 0.0) \n", "17 1.0 0.0 (1.0) (0.0, 1.0) (1.0, 0.0, 0.0) \n", "18 1.0 1.0 (0.0) (1.0, 0.0) (0.0, 1.0, 0.0) \n", "19 2.0 0.0 (0.0) (0.0, 1.0) (1.0, 0.0, 0.0) \n", "\n", " oneHotAge features \\\n", "0 (0.0, 0.0, 1.0, 0.0) (3.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0) \n", "1 (0.0, 0.0, 1.0, 0.0) [3.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0] \n", "2 (0.0, 0.0, 0.0, 1.0) [2.0, 0.0, 0.0, 3.0, 1.0, 2.0, 0.0] \n", "3 (0.0, 1.0, 0.0, 0.0) (3.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0) \n", "4 (0.0, 1.0, 0.0, 0.0) [3.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0] \n", "5 (1.0, 0.0, 0.0, 0.0) (3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0) \n", "6 (0.0, 1.0, 0.0, 0.0) [3.0, 0.0, 0.0, 1.0, 0.0, 2.0, 1.0] \n", "7 (0.0, 1.0, 0.0, 0.0) [2.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0] \n", "8 (0.0, 1.0, 0.0, 0.0) [3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0] \n", "9 (0.0, 1.0, 0.0, 0.0) [3.0, 2.0, 0.0, 1.0, 2.0, 0.0, 0.0] \n", "10 (0.0, 1.0, 0.0, 0.0) (3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0) \n", "11 (0.0, 0.0, 1.0, 0.0) (1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0) \n", "12 (0.0, 1.0, 0.0, 0.0) [1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 1.0] \n", "13 (0.0, 0.0, 0.0, 1.0) [2.0, 1.0, 0.0, 3.0, 
2.0, 0.0, 0.0] \n", "14 (0.0, 0.0, 1.0, 0.0) [1.0, 1.0, 0.0, 2.0, 3.0, 0.0, 1.0] \n", "15 (0.0, 1.0, 0.0, 0.0) [2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 1.0] \n", "16 (0.0, 0.0, 1.0, 0.0) [2.0, 0.0, 0.0, 2.0, 1.0, 2.0, 0.0] \n", "17 (0.0, 1.0, 0.0, 0.0) (3.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0) \n", "18 (0.0, 1.0, 0.0, 0.0) [3.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0] \n", "19 (0.0, 0.0, 1.0, 0.0) [3.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0] \n", "\n", " rawPrediction probability \\\n", "0 [29.2960574888, 0.703942511236] [0.976535249625, 0.0234647503745] \n", "1 [17.9363522365, 12.0636477635] [0.597878407882, 0.402121592118] \n", "2 [26.2911066277, 3.70889337228] [0.876370220924, 0.123629779076] \n", "3 [25.9086506078, 4.09134939223] [0.863621686926, 0.136378313074] \n", "4 [21.2731321988, 8.72686780122] [0.709104406626, 0.290895593374] \n", "5 [21.79357844, 8.20642155996] [0.726452614668, 0.273547385332] \n", "6 [6.58546403436, 23.4145359656] [0.219515467812, 0.780484532188] \n", "7 [26.9219114219, 3.07808857809] [0.897397047397, 0.102602952603] \n", "8 [5.95376028838, 24.0462397116] [0.198458676279, 0.801541323721] \n", "9 [27.683257006, 2.31674299405] [0.922775233532, 0.0772247664683] \n", "10 [27.1796680333, 2.8203319667] [0.905988934443, 0.0940110655565] \n", "11 [14.2653782322, 15.7346217678] [0.475512607739, 0.524487392261] \n", "12 [0.295454545455, 29.7045454545] [0.00984848484848, 0.990151515152] \n", "13 [27.5151896549, 2.4848103451] [0.917172988497, 0.0828270115032] \n", "14 [0.0, 30.0] [0.0, 1.0] \n", "15 [2.02801517479, 27.9719848252] [0.0676005058263, 0.932399494174] \n", "16 [26.5861882672, 3.41381173276] [0.886206275575, 0.113793724425] \n", "17 [25.8019873912, 4.19801260875] [0.860066246375, 0.139933753625] \n", "18 [20.8412096143, 9.15879038571] [0.694706987143, 0.305293012857] \n", "19 [17.3386478658, 12.6613521342] [0.577954928861, 0.422045071139] \n", "\n", " prediction \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "5 0.0 \n", "6 1.0 \n", "7 0.0 \n", "8 1.0 
\n", "9 0.0 \n", "10 0.0 \n", "11 1.0 \n", "12 1.0 \n", "13 0.0 \n", "14 1.0 \n", "15 1.0 \n", "16 0.0 \n", "17 0.0 \n", "18 0.0 \n", "19 0.0 \n", "\n", "[20 rows x 24 columns]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test.limit(20).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Write the predictions to a CSV file in the format Kaggle specifies (`PassengerId`, `Survived`)." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql.types import IntegerType\n", "\n", "csvPath = 'prediction.csv'\n", "\n", "# Two columns only: PassengerId plus the prediction cast to an integer `Survived`.\n", "submission = test.select('PassengerId', test['prediction'].cast(IntegerType()).alias('Survived'))\n", "\n", "# coalesce(1) keeps the output in a single part file under csvPath.\n", "# mode='overwrite' replaces the output of an earlier run; the previous mode='ignore' silently skipped\n", "# the write whenever csvPath already existed, leaving stale predictions on disk after retraining.\n", "submission.coalesce(1).write.csv(csvPath, header='true', mode='overwrite')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }