{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Example: Covertype Data Set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following example uses the (processed) Covertype dataset from [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Covertype).\n", "\n", "It is a dataset with both categorical (`wilderness_area` and `soil_type`) and continuous (the rest) features. The target is the `cover_type` column:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "from pyspark import SparkConf, SparkContext" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "conf = SparkConf() #.set(\"spark.jars\", \"/Users/per0/wa/spark_wa/spark-tree-plotting/target/scala-2.11/spark-tree-plotting_0.2.jar\")\n", "\n", "sc = SparkContext.getOrCreate()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "spark = SparkSession.builder.getOrCreate()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- elevation: long (nullable = true)\n", " |-- aspect: long (nullable = true)\n", " |-- slope: long (nullable = true)\n", " |-- horizontal_distance_to_hydrology: long (nullable = true)\n", " |-- vertical_distance_to_hydrology: long (nullable = true)\n", " |-- horizontal_distance_to_roadways: long (nullable = true)\n", " |-- hillshade_9am: long (nullable = true)\n", " |-- hillshade_noon: long (nullable = true)\n", " |-- hillshade_3pm: long (nullable = true)\n", " |-- horizontal_distance_to_fire_points: long (nullable = true)\n", " |-- wilderness_area: string (nullable = true)\n", " |-- soil_type: string (nullable = true)\n", " |-- cover_type: string (nullable = true)\n", "\n" ] } ], "source": [ "covertype_dataset = 
spark.read.parquet(\"covertype_dataset.snappy.parquet\")\n", "\n", "covertype_dataset.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first 10 rows:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
elevationaspectslopehorizontal_distance_to_hydrologyvertical_distance_to_hydrologyhorizontal_distance_to_roadwayshillshade_9amhillshade_noonhillshade_3pmhorizontal_distance_to_fire_pointswilderness_areasoil_typecover_type
0259651325805102212321486279rawah wilderness areaSoil_Type_7745Aspen
12590562212-63902202351516225rawah wilderness areaSoil_Type_7745Aspen
2280413992686531802342381356121rawah wilderness areaSoil_Type_4744Lodgepole Pine
327851551824211830902382381226211rawah wilderness areaSoil_Type_7746Lodgepole Pine
42595452153-13912202341506172rawah wilderness areaSoil_Type_7745Aspen
525791326300-15672302371406031rawah wilderness areaSoil_Type_7745Lodgepole Pine
6260645727056332222251386256rawah wilderness areaSoil_Type_7745Aspen
7260549423475732222301446228rawah wilderness areaSoil_Type_7745Aspen
82617459240566662232211336244rawah wilderness areaSoil_Type_7745Aspen
926125910247116362282191246230rawah wilderness areaSoil_Type_7745Aspen
\n", "
" ], "text/plain": [ " elevation aspect slope horizontal_distance_to_hydrology \\\n", "0 2596 51 3 258 \n", "1 2590 56 2 212 \n", "2 2804 139 9 268 \n", "3 2785 155 18 242 \n", "4 2595 45 2 153 \n", "5 2579 132 6 300 \n", "6 2606 45 7 270 \n", "7 2605 49 4 234 \n", "8 2617 45 9 240 \n", "9 2612 59 10 247 \n", "\n", " vertical_distance_to_hydrology horizontal_distance_to_roadways \\\n", "0 0 510 \n", "1 -6 390 \n", "2 65 3180 \n", "3 118 3090 \n", "4 -1 391 \n", "5 -15 67 \n", "6 5 633 \n", "7 7 573 \n", "8 56 666 \n", "9 11 636 \n", "\n", " hillshade_9am hillshade_noon hillshade_3pm \\\n", "0 221 232 148 \n", "1 220 235 151 \n", "2 234 238 135 \n", "3 238 238 122 \n", "4 220 234 150 \n", "5 230 237 140 \n", "6 222 225 138 \n", "7 222 230 144 \n", "8 223 221 133 \n", "9 228 219 124 \n", "\n", " horizontal_distance_to_fire_points wilderness_area soil_type \\\n", "0 6279 rawah wilderness area Soil_Type_7745 \n", "1 6225 rawah wilderness area Soil_Type_7745 \n", "2 6121 rawah wilderness area Soil_Type_4744 \n", "3 6211 rawah wilderness area Soil_Type_7746 \n", "4 6172 rawah wilderness area Soil_Type_7745 \n", "5 6031 rawah wilderness area Soil_Type_7745 \n", "6 6256 rawah wilderness area Soil_Type_7745 \n", "7 6228 rawah wilderness area Soil_Type_7745 \n", "8 6244 rawah wilderness area Soil_Type_7745 \n", "9 6230 rawah wilderness area Soil_Type_7745 \n", "\n", " cover_type \n", "0 Aspen \n", "1 Aspen \n", "2 Lodgepole Pine \n", "3 Lodgepole Pine \n", "4 Aspen \n", "5 Lodgepole Pine \n", "6 Aspen \n", "7 Aspen \n", "8 Aspen \n", "9 Aspen " ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "covertype_dataset.limit(10).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order for Spark's `DecisionTreeClassifier` to work with the categorical features (as well as the target), we first need to use 
[`pyspark.ml.feature.StringIndexer`](https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer)s to generate a numeric representation for those columns:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from pyspark.ml.feature import StringIndexer\n", "\n", "string_indexer_wilderness = StringIndexer(inputCol=\"wilderness_area\", outputCol=\"wilderness_area_indexed\")\n", "\n", "string_indexer_soil = StringIndexer(inputCol=\"soil_type\", outputCol=\"soil_type_indexed\")\n", " \n", "string_indexer_cover = StringIndexer(inputCol=\"cover_type\", outputCol=\"cover_type_indexed\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To generate the new *StringIndexerModels*, we call `.fit()` on each `StringIndexer` instance:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "string_indexer_wilderness_model = string_indexer_wilderness.fit(covertype_dataset)\n", "\n", "string_indexer_soil_model = string_indexer_soil.fit(covertype_dataset)\n", "\n", "string_indexer_cover_model = string_indexer_cover.fit(covertype_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we create the new columns:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "covertype_dataset_indexed_features = string_indexer_cover_model.transform(string_indexer_soil_model\n", " .transform(string_indexer_wilderness_model\n", " .transform(covertype_dataset)\n", " )\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "New columns can be seen at the right:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
elevationaspectslopehorizontal_distance_to_hydrologyvertical_distance_to_hydrologyhorizontal_distance_to_roadwayshillshade_9amhillshade_noonhillshade_3pmhorizontal_distance_to_fire_pointswilderness_areasoil_typecover_typewilderness_area_indexedsoil_type_indexedcover_type_indexed
0259651325805102212321486279rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
12590562212-63902202351516225rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
2280413992686531802342381356121rawah wilderness areaSoil_Type_4744Lodgepole Pine0.07.00.0
327851551824211830902382381226211rawah wilderness areaSoil_Type_7746Lodgepole Pine0.06.00.0
42595452153-13912202341506172rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
525791326300-15672302371406031rawah wilderness areaSoil_Type_7745Lodgepole Pine0.00.00.0
6260645727056332222251386256rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
7260549423475732222301446228rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
82617459240566662232211336244rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
926125910247116362282191246230rawah wilderness areaSoil_Type_7745Aspen0.00.05.0
\n", "
" ], "text/plain": [ " elevation aspect slope horizontal_distance_to_hydrology \\\n", "0 2596 51 3 258 \n", "1 2590 56 2 212 \n", "2 2804 139 9 268 \n", "3 2785 155 18 242 \n", "4 2595 45 2 153 \n", "5 2579 132 6 300 \n", "6 2606 45 7 270 \n", "7 2605 49 4 234 \n", "8 2617 45 9 240 \n", "9 2612 59 10 247 \n", "\n", " vertical_distance_to_hydrology horizontal_distance_to_roadways \\\n", "0 0 510 \n", "1 -6 390 \n", "2 65 3180 \n", "3 118 3090 \n", "4 -1 391 \n", "5 -15 67 \n", "6 5 633 \n", "7 7 573 \n", "8 56 666 \n", "9 11 636 \n", "\n", " hillshade_9am hillshade_noon hillshade_3pm \\\n", "0 221 232 148 \n", "1 220 235 151 \n", "2 234 238 135 \n", "3 238 238 122 \n", "4 220 234 150 \n", "5 230 237 140 \n", "6 222 225 138 \n", "7 222 230 144 \n", "8 223 221 133 \n", "9 228 219 124 \n", "\n", " horizontal_distance_to_fire_points wilderness_area soil_type \\\n", "0 6279 rawah wilderness area Soil_Type_7745 \n", "1 6225 rawah wilderness area Soil_Type_7745 \n", "2 6121 rawah wilderness area Soil_Type_4744 \n", "3 6211 rawah wilderness area Soil_Type_7746 \n", "4 6172 rawah wilderness area Soil_Type_7745 \n", "5 6031 rawah wilderness area Soil_Type_7745 \n", "6 6256 rawah wilderness area Soil_Type_7745 \n", "7 6228 rawah wilderness area Soil_Type_7745 \n", "8 6244 rawah wilderness area Soil_Type_7745 \n", "9 6230 rawah wilderness area Soil_Type_7745 \n", "\n", " cover_type wilderness_area_indexed soil_type_indexed \\\n", "0 Aspen 0.0 0.0 \n", "1 Aspen 0.0 0.0 \n", "2 Lodgepole Pine 0.0 7.0 \n", "3 Lodgepole Pine 0.0 6.0 \n", "4 Aspen 0.0 0.0 \n", "5 Lodgepole Pine 0.0 0.0 \n", "6 Aspen 0.0 0.0 \n", "7 Aspen 0.0 0.0 \n", "8 Aspen 0.0 0.0 \n", "9 Aspen 0.0 0.0 \n", "\n", " cover_type_indexed \n", "0 5.0 \n", "1 5.0 \n", "2 0.0 \n", "3 0.0 \n", "4 5.0 \n", "5 0.0 \n", "6 5.0 \n", "7 5.0 \n", "8 5.0 \n", "9 5.0 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "covertype_dataset_indexed_features.limit(10).toPandas()" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "Now, we just have to `VectorAssemble` our features to create the feature vector:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from pyspark.ml.feature import VectorAssembler\n", "\n", "feature_columns = [\"elevation\",\n", " \"aspect\",\n", " \"slope\",\n", " \"horizontal_distance_to_hydrology\",\n", " \"vertical_distance_to_hydrology\",\n", " \"horizontal_distance_to_roadways\",\n", " \"hillshade_9am\",\n", " \"hillshade_noon\",\n", " \"hillshade_3pm\",\n", " \"horizontal_distance_to_fire_points\",\n", " \"wilderness_area_indexed\",\n", " \"soil_type_indexed\"]\n", "\n", "feature_assembler = VectorAssembler(inputCols=feature_columns, outputCol=\"features\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we have our dataset prepared for ML:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "covertype_dataset_prepared = feature_assembler.transform(covertype_dataset_indexed_features)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- elevation: long (nullable = true)\n", " |-- aspect: long (nullable = true)\n", " |-- slope: long (nullable = true)\n", " |-- horizontal_distance_to_hydrology: long (nullable = true)\n", " |-- vertical_distance_to_hydrology: long (nullable = true)\n", " |-- horizontal_distance_to_roadways: long (nullable = true)\n", " |-- hillshade_9am: long (nullable = true)\n", " |-- hillshade_noon: long (nullable = true)\n", " |-- hillshade_3pm: long (nullable = true)\n", " |-- horizontal_distance_to_fire_points: long (nullable = true)\n", " |-- wilderness_area: string (nullable = true)\n", " |-- soil_type: string (nullable = true)\n", " |-- cover_type: string (nullable = true)\n", " |-- wilderness_area_indexed: double (nullable = 
false)\n", " |-- soil_type_indexed: double (nullable = false)\n", " |-- cover_type_indexed: double (nullable = false)\n", " |-- features: vector (nullable = true)\n", "\n" ] } ], "source": [ "covertype_dataset_prepared.printSchema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's build a simple `pyspark.ml.classification.DecisionTreeClassifier`:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# from pyspark.ml.classification import DecisionTreeClassifier\n", "\n", "# dtree = DecisionTreeClassifier(featuresCol=\"features\",\n", "# labelCol=\"cover_type_indexed\",\n", "# maxDepth=3,\n", "# maxBins=50)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We fit it, and we get our `DecisionTreeClassificationModel`:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "DecisionTreeClassificationModel (uid=DecisionTreeClassifier_99684a674979) of depth 3 with 11 nodes" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# dtree_model = dtree.fit(covertype_dataset_prepared)\n", "\n", "# dtree_model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pyspark.ml.classification import DecisionTreeClassificationModel" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "dtree_model = DecisionTreeClassificationModel.load('tree_model')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.toDebugString` attribute prints the decision rules for the tree, but it is not very user-friendly:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f84d275537cd) of depth 3 with 7 nodes\n", " If (feature 0 <= 
3050.5)\n", " If (feature 0 <= 2540.5)\n", " If (feature 10 in {0.0})\n", " Predict: 0.0\n", " Else (feature 10 not in {0.0})\n", " Predict: 2.0\n", " Else (feature 0 > 2540.5)\n", " Predict: 0.0\n", " Else (feature 0 > 3050.5)\n", " Predict: 1.0\n", "\n" ] } ], "source": [ "print(dtree_model.toDebugString)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Perhaps `spark_tree_plotting` may be helpful here ;)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "ename": "TypeError", "evalue": "isinstance() arg 2 must be a type or tuple of types", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mfilled\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# With color!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mroundedCorners\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# Rounded corners in the nodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mroundLeaves\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m \u001b[0;31m# Leaves will be ellipses instead of rectangles\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m )\n", "\u001b[0;32m~/wa/spark_wa/spark-tree-plotting/python/spark_tree_plotting.py\u001b[0m in \u001b[0;36mplot_tree\u001b[0;34m(DecisionTreeClassificationModel, featureNames, categoryNames, classNames, filled, roundedCorners, roundLeaves)\u001b[0m\n\u001b[1;32m 434\u001b[0m 
\u001b[0mfilled\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfilled\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[0mroundedCorners\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mroundedCorners\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 436\u001b[0;31m \u001b[0mroundLeaves\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mroundLeaves\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 437\u001b[0m )\n\u001b[1;32m 438\u001b[0m )\n", "\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pydot/__init__.py\u001b[0m in \u001b[0;36mgraph_from_dot_data\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 216\u001b[0m \"\"\"\n\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 218\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdot_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_dot_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pydot/dot_parser.py\u001b[0m in \u001b[0;36mparse_dot_data\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 521\u001b[0m \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraphparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparseString\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 522\u001b[0;31m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokens\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pyparsing.py\u001b[0m in \u001b[0;36mparseString\u001b[0;34m(self, instring, parseAll)\u001b[0m\n\u001b[1;32m 1816\u001b[0m \u001b[0minstring\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpandtabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1817\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1818\u001b[0;31m \u001b[0mloc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parse\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1819\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mparseAll\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1820\u001b[0m \u001b[0mloc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreParse\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pyparsing.py\u001b[0m in \u001b[0;36m_parseNoCache\u001b[0;34m(self, instring, loc, doActions, callPreParse)\u001b[0m\n\u001b[1;32m 1593\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparseAction\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1594\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1595\u001b[0;31m \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokensStart\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretTokens\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1596\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mIndexError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mparse_action_exc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1597\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParseException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"exception raised in parse action\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pyparsing.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1217\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlimit\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1218\u001b[0m 
\u001b[0mfoundArity\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1219\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pydot/dot_parser.py\u001b[0m in \u001b[0;36mpush_top_graph_stmt\u001b[0;34m(str, loc, toks)\u001b[0m\n\u001b[1;32m 79\u001b[0m if( isinstance(element, (ParseResults, tuple, list)) and\n\u001b[1;32m 80\u001b[0m len(element) == 1 and isinstance(element[0], str) ):\n\u001b[0;32m---> 81\u001b[0;31m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 82\u001b[0m \u001b[0melement\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melement\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: isinstance() arg 2 must be a type or tuple of types" ] } ], "source": [ "import sys\n", "sys.path.insert(0, '/Users/per0/wa/spark_wa/spark-tree-plotting/python')\n", "from spark_tree_plotting import plot_tree\n", "\n", "tree_plot = plot_tree(dtree_model,\n", " featureNames=feature_columns,\n", " categoryNames={\"wilderness_area_indexed\":string_indexer_wilderness_model.labels,\n", " \"soil_type_indexed\":string_indexer_soil_model.labels},\n", " classNames=string_indexer_cover_model.labels,\n", " filled=True, # With color!\n", " roundedCorners=True, # Rounded corners in the nodes\n", " roundLeaves=True # Leaves will be ellipses instead of rectangles\n", " )" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "json_tree = sc._jvm.com.vfive.spark.ml.SparkMLTree(dtree_model._java_obj)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "{\n", " \"featureIndex\":0,\n", " \"gain\":0.08681394658400207,\n", " \"impurity\":0.6230942824070332,\n", " \"threshold\":3050.5,\n", " \"nodeType\":\"internal\",\n", " \"splitType\":\"continuous\",\n", " \"prediction\":0.0,\n", " \"leftChild\":{\n", " \"featureIndex\":0,\n", " \"gain\":0.08616165361635758,\n", " \"impurity\":0.5539261911259398,\n", " \"threshold\":2540.5,\n", " \"nodeType\":\"internal\",\n", " \"splitType\":\"continuous\",\n", " \"prediction\":0.0,\n", " \"leftChild\":{\n", " \"featureIndex\":10,\n", " \"gain\":0.04640621444482429,\n", " \"impurity\":0.6171371727013576,\n", " \"nodeType\":\"internal\",\n", " \"splitType\":\"categorical\",\n", " \"leftCategories\":[\n", " 0.0\n", " ],\n", " \"rightCategories\":[\n", " 1.0,\n", " 2.0,\n", " 3.0\n", " ],\n", " \"prediction\":2.0,\n", " \"leftChild\":{\n", " \"impurity\":0.18642232564845895,\n", " \"nodeType\":\"leaf\",\n", " \"prediction\":0.0\n", " },\n", " \"rightChild\":{\n", " \"impurity\":0.5893401621499551,\n", " \"nodeType\":\"leaf\",\n", " \"prediction\":2.0\n", " }\n", " },\n", " \"rightChild\":{\n", " \"impurity\":0.4430125702798494,\n", " \"nodeType\":\"leaf\",\n", " \"prediction\":0.0\n", " }\n", " },\n", " \"rightChild\":{\n", " \"impurity\":0.5109863148417016,\n", " \"nodeType\":\"leaf\",\n", " \"prediction\":1.0\n", " }\n", "}\n" ] } ], "source": [ "print(json_tree.toJsonPlotFormat())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "Image(tree_plot)" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }