{
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "from pyspark.sql import SparkSession"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "e79647af-d6b4-4484-9e41-17af77f62bea"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# getOrCreate() reuses an already-running session instead of starting a second one.\n",
    "spark = SparkSession.builder.appName('myproj').getOrCreate()"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "e0b05115-2da9-4cf2-9842-074dc689c9e3"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# The first row of the CSV supplies the column names; column types are inferred.\n",
    "data = spark.read.csv(\n",
    "    \"dbfs:/FileStore/shared_uploads/dizhen@hsph.harvard.edu/titanic.csv\",\n",
    "    inferSchema=True,\n",
    "    header=True,\n",
    ")"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "7f8ed7d3-16e2-4a17-94f4-6faacd9c8e53"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "data.printSchema()"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "390a5dc0-44ab-42ee-aabd-71991dd1bfdf"}},
   "outputs": [
    {
     "output_type": "display_data",
     "metadata": {"application/vnd.databricks.v1+output": {"datasetInfos": [], "data": "root\n |-- PassengerId: integer (nullable = true)\n |-- Survived: integer (nullable = true)\n |-- Pclass: integer (nullable = true)\n |-- Name: string (nullable = true)\n |-- Sex: string (nullable = true)\n |-- Age: double (nullable = true)\n |-- SibSp: integer (nullable = true)\n |-- Parch: integer (nullable = true)\n |-- Ticket: string (nullable = true)\n |-- Fare: double (nullable = true)\n |-- Cabin: string (nullable = true)\n |-- Embarked: string (nullable = true)\n\n", "removedWidgets": [], "addedWidgets": {}, "metadata": {}, "type": "ansi", "arguments": {}}},
     "data": {"text/plain": ["root\n |-- PassengerId: integer (nullable = true)\n |-- Survived: integer (nullable = true)\n |-- Pclass: integer (nullable = true)\n |-- Name: string (nullable = true)\n |-- Sex: string (nullable = true)\n |-- Age: double (nullable = true)\n |-- SibSp: integer (nullable = true)\n |-- Parch: integer (nullable = true)\n |-- Ticket: string (nullable = true)\n |-- Fare: double (nullable = true)\n |-- Cabin: string (nullable = true)\n |-- Embarked: string (nullable = true)\n\n"]}
    }
   ],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "data.columns"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "d1cb3b3c-643a-4987-9ed5-2f93d4dfc6fe"}},
   "outputs": [
    {
     "output_type": "display_data",
     "metadata": {"application/vnd.databricks.v1+output": {"datasetInfos": [], "data": "Out[5]: ['PassengerId',\n 'Survived',\n 'Pclass',\n 'Name',\n 'Sex',\n 'Age',\n 'SibSp',\n 'Parch',\n 'Ticket',\n 'Fare',\n 'Cabin',\n 'Embarked']", "removedWidgets": [], "addedWidgets": {}, "metadata": {}, "type": "ansi", "arguments": {}}},
     "data": {"text/plain": ["Out[5]: ['PassengerId',\n 'Survived',\n 'Pclass',\n 'Name',\n 'Sex',\n 'Age',\n 'SibSp',\n 'Parch',\n 'Ticket',\n 'Fare',\n 'Cabin',\n 'Embarked']"]}
    }
   ],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Keep the label (Survived) and the candidate predictors;\n",
    "# PassengerId, Name, Ticket and Cabin are left out.\n",
    "my_cols = data.select([\n",
    "    'Survived',\n",
    "    'Pclass',\n",
    "    'Sex',\n",
    "    'Age',\n",
    "    'SibSp',\n",
    "    'Parch',\n",
    "    'Fare',\n",
    "    'Embarked',\n",
    "])"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "d5ac3155-492f-4910-89b5-a89f45eeaba5"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Drop every row that has a null in any of the selected columns.\n",
    "my_final_data = my_cols.na.drop()"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "50a490db-3cb6-436f-9aeb-af3ce6460bed"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Handle categorical features\n",
    "\n",
    "`Sex` and `Embarked` are string columns, so each one is indexed with `StringIndexer` and then one-hot encoded before it is assembled into the feature vector."
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "30065e37-4030-4328-b1f2-1322016bd880"}}
  },
  {
   "cell_type": "code",
   "source": [
    "from pyspark.ml.feature import (VectorAssembler, VectorIndexer,\n",
    "                                OneHotEncoder, StringIndexer)"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "9fa2c9ed-08b1-42a7-bf50-0cd49587ea25"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# String category -> numeric index -> one-hot vector.\n",
    "gender_indexer = StringIndexer(inputCol='Sex', outputCol='SexIndex')\n",
    "gender_encoder = OneHotEncoder(inputCol='SexIndex', outputCol='SexVec')"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "842b61b2-329d-4591-8716-b12303894d52"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Same index-then-encode treatment for the port of embarkation.\n",
    "embark_indexer = StringIndexer(inputCol='Embarked', outputCol='EmbarkIndex')\n",
    "embark_encoder = OneHotEncoder(inputCol='EmbarkIndex', outputCol='EmbarkVec')"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "0338b343-c97d-43f3-a108-2368a6d4acc4"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Combine the numeric columns and the encoded vectors into one 'features' column.\n",
    "assembler = VectorAssembler(\n",
    "    inputCols=['Pclass', 'SexVec', 'Age', 'SibSp', 'Parch', 'Fare', 'EmbarkVec'],\n",
    "    outputCol='features',\n",
    ")"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "81e8854d-ed8e-4742-8eec-da9b8317f83f"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Pipeline\n",
    "\n",
    "The indexers, encoders, assembler and logistic regression are chained into a single `Pipeline`, so the same preprocessing is applied when fitting on the training split and when transforming the test split."
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "cd0cae6c-6b06-4d1f-ab60-7237d5d0a0d9"}}
  },
  {
   "cell_type": "code",
   "source": [
    "from pyspark.ml.classification import LogisticRegression\n",
    "from pyspark.ml import Pipeline"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "45a5ca47-5e06-49d7-a541-0d7179ce21ce"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "log_reg_titanic = LogisticRegression(featuresCol='features', labelCol='Survived')"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "086d34b6-de8e-4e58-a812-83c3b03be960"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Stage order matters: index, then encode, then assemble, then fit the classifier.\n",
    "pipeline = Pipeline(stages=[gender_indexer, embark_indexer,\n",
    "                            gender_encoder, embark_encoder,\n",
    "                            assembler, log_reg_titanic])"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "aeb77946-c52d-4a79-be22-4fa6ef686698"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# 70/30 train/test split; the fixed seed makes the split repeatable between runs.\n",
    "train_titanic_data, test_titanic_data = my_final_data.randomSplit([0.7, 0.3], seed=42)"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "cc6fe1ac-65af-4129-8be2-e220cfebef44"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "fit_model = pipeline.fit(train_titanic_data)"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "df5441cc-54fd-4cf3-a596-26541b0756fa"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "results = fit_model.transform(test_titanic_data)"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "4ee82d8f-bf72-4a3f-a208-c0c94da5370d"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "3dccc143-beed-4b52-b1ae-eca5c019e291"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Score with the model's continuous 'rawPrediction' output. Feeding the hard 0/1\n",
    "# 'prediction' column instead reduces the ROC curve to a single operating point,\n",
    "# so the resulting number is not the model's real area under the curve.\n",
    "my_eval = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction',\n",
    "                                        labelCol='Survived',\n",
    "                                        metricName='areaUnderROC')"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "ae7f0222-5d71-4fd2-aea3-8aa597ad53ad"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Eyeball the true label next to the predicted class for the first test rows.\n",
    "results.select('Survived', 'prediction').show()"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "1cd0efda-b5fe-4a68-a7ae-0b929e0d5ef1"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "AUC = my_eval.evaluate(results)"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "82782a6d-8d88-4a35-906f-cf3bf5cb37bb"}},
   "outputs": [],
   "execution_count": 0
  },
  {
   "cell_type": "code",
   "source": [
    "# Area under the ROC curve on the held-out test split.\n",
    "AUC"
   ],
   "metadata": {"application/vnd.databricks.v1+cell": {"title": "", "showTitle": false, "inputWidgets": {}, "nuid": "080b8e72-a117-4be4-950f-45e86602478a"}},
   "outputs": [],
   "execution_count": 0
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "notebookName": "ml-logistic-regression",
   "dashboards": [],
   "notebookMetadata": {"pythonIndentUnit": 4},
   "language": "python",
   "widgets": {},
   "notebookOrigID": 3598965581908874
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}