{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Example: Covertype Data Set"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following example uses the (processed) Covertype dataset from [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Covertype).\n",
"\n",
"It is a dataset with both categorical (`wilderness_area` and `soil_type`) and continuous (the rest) features. The target is the `cover_type` column:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark import SparkConf, SparkContext"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Spark configuration for this notebook. To load a locally built spark-tree-plotting JAR,\n",
"# chain .set(\"spark.jars\", \"path/to/target/scala-2.11/spark-tree-plotting_0.2.jar\") onto SparkConf().\n",
"conf = SparkConf()\n",
"\n",
"# Hand the configuration to getOrCreate(); called without arguments it never sees `conf`.\n",
"# If a SparkContext is already running, that existing context is returned instead.\n",
"sc = SparkContext.getOrCreate(conf)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Reuse the active SparkSession if there is one; otherwise a new one is started.\n",
"spark = (\n",
"    SparkSession.builder\n",
"    .getOrCreate()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- elevation: long (nullable = true)\n",
" |-- aspect: long (nullable = true)\n",
" |-- slope: long (nullable = true)\n",
" |-- horizontal_distance_to_hydrology: long (nullable = true)\n",
" |-- vertical_distance_to_hydrology: long (nullable = true)\n",
" |-- horizontal_distance_to_roadways: long (nullable = true)\n",
" |-- hillshade_9am: long (nullable = true)\n",
" |-- hillshade_noon: long (nullable = true)\n",
" |-- hillshade_3pm: long (nullable = true)\n",
" |-- horizontal_distance_to_fire_points: long (nullable = true)\n",
" |-- wilderness_area: string (nullable = true)\n",
" |-- soil_type: string (nullable = true)\n",
" |-- cover_type: string (nullable = true)\n",
"\n"
]
}
],
"source": [
"covertype_dataset = spark.read.parquet(\"covertype_dataset.snappy.parquet\")\n",
"\n",
"covertype_dataset.printSchema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first 10 rows:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" elevation | \n",
" aspect | \n",
" slope | \n",
" horizontal_distance_to_hydrology | \n",
" vertical_distance_to_hydrology | \n",
" horizontal_distance_to_roadways | \n",
" hillshade_9am | \n",
" hillshade_noon | \n",
" hillshade_3pm | \n",
" horizontal_distance_to_fire_points | \n",
" wilderness_area | \n",
" soil_type | \n",
" cover_type | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2596 | \n",
" 51 | \n",
" 3 | \n",
" 258 | \n",
" 0 | \n",
" 510 | \n",
" 221 | \n",
" 232 | \n",
" 148 | \n",
" 6279 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
" 1 | \n",
" 2590 | \n",
" 56 | \n",
" 2 | \n",
" 212 | \n",
" -6 | \n",
" 390 | \n",
" 220 | \n",
" 235 | \n",
" 151 | \n",
" 6225 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
" 2 | \n",
" 2804 | \n",
" 139 | \n",
" 9 | \n",
" 268 | \n",
" 65 | \n",
" 3180 | \n",
" 234 | \n",
" 238 | \n",
" 135 | \n",
" 6121 | \n",
" rawah wilderness area | \n",
" Soil_Type_4744 | \n",
" Lodgepole Pine | \n",
"
\n",
" \n",
" 3 | \n",
" 2785 | \n",
" 155 | \n",
" 18 | \n",
" 242 | \n",
" 118 | \n",
" 3090 | \n",
" 238 | \n",
" 238 | \n",
" 122 | \n",
" 6211 | \n",
" rawah wilderness area | \n",
" Soil_Type_7746 | \n",
" Lodgepole Pine | \n",
"
\n",
" \n",
" 4 | \n",
" 2595 | \n",
" 45 | \n",
" 2 | \n",
" 153 | \n",
" -1 | \n",
" 391 | \n",
" 220 | \n",
" 234 | \n",
" 150 | \n",
" 6172 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
" 5 | \n",
" 2579 | \n",
" 132 | \n",
" 6 | \n",
" 300 | \n",
" -15 | \n",
" 67 | \n",
" 230 | \n",
" 237 | \n",
" 140 | \n",
" 6031 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Lodgepole Pine | \n",
"
\n",
" \n",
" 6 | \n",
" 2606 | \n",
" 45 | \n",
" 7 | \n",
" 270 | \n",
" 5 | \n",
" 633 | \n",
" 222 | \n",
" 225 | \n",
" 138 | \n",
" 6256 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
" 7 | \n",
" 2605 | \n",
" 49 | \n",
" 4 | \n",
" 234 | \n",
" 7 | \n",
" 573 | \n",
" 222 | \n",
" 230 | \n",
" 144 | \n",
" 6228 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
" 8 | \n",
" 2617 | \n",
" 45 | \n",
" 9 | \n",
" 240 | \n",
" 56 | \n",
" 666 | \n",
" 223 | \n",
" 221 | \n",
" 133 | \n",
" 6244 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
" 9 | \n",
" 2612 | \n",
" 59 | \n",
" 10 | \n",
" 247 | \n",
" 11 | \n",
" 636 | \n",
" 228 | \n",
" 219 | \n",
" 124 | \n",
" 6230 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" elevation aspect slope horizontal_distance_to_hydrology \\\n",
"0 2596 51 3 258 \n",
"1 2590 56 2 212 \n",
"2 2804 139 9 268 \n",
"3 2785 155 18 242 \n",
"4 2595 45 2 153 \n",
"5 2579 132 6 300 \n",
"6 2606 45 7 270 \n",
"7 2605 49 4 234 \n",
"8 2617 45 9 240 \n",
"9 2612 59 10 247 \n",
"\n",
" vertical_distance_to_hydrology horizontal_distance_to_roadways \\\n",
"0 0 510 \n",
"1 -6 390 \n",
"2 65 3180 \n",
"3 118 3090 \n",
"4 -1 391 \n",
"5 -15 67 \n",
"6 5 633 \n",
"7 7 573 \n",
"8 56 666 \n",
"9 11 636 \n",
"\n",
" hillshade_9am hillshade_noon hillshade_3pm \\\n",
"0 221 232 148 \n",
"1 220 235 151 \n",
"2 234 238 135 \n",
"3 238 238 122 \n",
"4 220 234 150 \n",
"5 230 237 140 \n",
"6 222 225 138 \n",
"7 222 230 144 \n",
"8 223 221 133 \n",
"9 228 219 124 \n",
"\n",
" horizontal_distance_to_fire_points wilderness_area soil_type \\\n",
"0 6279 rawah wilderness area Soil_Type_7745 \n",
"1 6225 rawah wilderness area Soil_Type_7745 \n",
"2 6121 rawah wilderness area Soil_Type_4744 \n",
"3 6211 rawah wilderness area Soil_Type_7746 \n",
"4 6172 rawah wilderness area Soil_Type_7745 \n",
"5 6031 rawah wilderness area Soil_Type_7745 \n",
"6 6256 rawah wilderness area Soil_Type_7745 \n",
"7 6228 rawah wilderness area Soil_Type_7745 \n",
"8 6244 rawah wilderness area Soil_Type_7745 \n",
"9 6230 rawah wilderness area Soil_Type_7745 \n",
"\n",
" cover_type \n",
"0 Aspen \n",
"1 Aspen \n",
"2 Lodgepole Pine \n",
"3 Lodgepole Pine \n",
"4 Aspen \n",
"5 Lodgepole Pine \n",
"6 Aspen \n",
"7 Aspen \n",
"8 Aspen \n",
"9 Aspen "
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"covertype_dataset.limit(10).toPandas()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order for Spark's `DecisionTreeClassifier` to work with the categorical features (as well as the target), we first need to use [`pyspark.ml.feature.StringIndexer`](https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer)s to generate a numeric representation for those columns:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pyspark.ml.feature import StringIndexer\n",
"\n",
"# One indexer per string column; each one writes to a new column with the suffix \"_indexed\".\n",
"string_indexer_wilderness = StringIndexer(\n",
"    inputCol=\"wilderness_area\",\n",
"    outputCol=\"wilderness_area_indexed\",\n",
")\n",
"\n",
"string_indexer_soil = StringIndexer(\n",
"    inputCol=\"soil_type\",\n",
"    outputCol=\"soil_type_indexed\",\n",
")\n",
"\n",
"# The target column is indexed the same way as the categorical features.\n",
"string_indexer_cover = StringIndexer(\n",
"    inputCol=\"cover_type\",\n",
"    outputCol=\"cover_type_indexed\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To generate the new *StringIndexerModels*, we call `.fit()` on each `StringIndexer` instance:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Fit every indexer on the full dataset, in the same order as they were defined.\n",
"(\n",
"    string_indexer_wilderness_model,\n",
"    string_indexer_soil_model,\n",
"    string_indexer_cover_model,\n",
") = (\n",
"    indexer.fit(covertype_dataset)\n",
"    for indexer in (\n",
"        string_indexer_wilderness,\n",
"        string_indexer_soil,\n",
"        string_indexer_cover,\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we create the new columns:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Apply the fitted indexers one after another: wilderness area, then soil type, then cover type.\n",
"with_wilderness_index = string_indexer_wilderness_model.transform(covertype_dataset)\n",
"with_soil_index = string_indexer_soil_model.transform(with_wilderness_index)\n",
"\n",
"covertype_dataset_indexed_features = string_indexer_cover_model.transform(with_soil_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The new columns appear on the right:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" elevation | \n",
" aspect | \n",
" slope | \n",
" horizontal_distance_to_hydrology | \n",
" vertical_distance_to_hydrology | \n",
" horizontal_distance_to_roadways | \n",
" hillshade_9am | \n",
" hillshade_noon | \n",
" hillshade_3pm | \n",
" horizontal_distance_to_fire_points | \n",
" wilderness_area | \n",
" soil_type | \n",
" cover_type | \n",
" wilderness_area_indexed | \n",
" soil_type_indexed | \n",
" cover_type_indexed | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2596 | \n",
" 51 | \n",
" 3 | \n",
" 258 | \n",
" 0 | \n",
" 510 | \n",
" 221 | \n",
" 232 | \n",
" 148 | \n",
" 6279 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
" 1 | \n",
" 2590 | \n",
" 56 | \n",
" 2 | \n",
" 212 | \n",
" -6 | \n",
" 390 | \n",
" 220 | \n",
" 235 | \n",
" 151 | \n",
" 6225 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
" 2 | \n",
" 2804 | \n",
" 139 | \n",
" 9 | \n",
" 268 | \n",
" 65 | \n",
" 3180 | \n",
" 234 | \n",
" 238 | \n",
" 135 | \n",
" 6121 | \n",
" rawah wilderness area | \n",
" Soil_Type_4744 | \n",
" Lodgepole Pine | \n",
" 0.0 | \n",
" 7.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 3 | \n",
" 2785 | \n",
" 155 | \n",
" 18 | \n",
" 242 | \n",
" 118 | \n",
" 3090 | \n",
" 238 | \n",
" 238 | \n",
" 122 | \n",
" 6211 | \n",
" rawah wilderness area | \n",
" Soil_Type_7746 | \n",
" Lodgepole Pine | \n",
" 0.0 | \n",
" 6.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 4 | \n",
" 2595 | \n",
" 45 | \n",
" 2 | \n",
" 153 | \n",
" -1 | \n",
" 391 | \n",
" 220 | \n",
" 234 | \n",
" 150 | \n",
" 6172 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
" 5 | \n",
" 2579 | \n",
" 132 | \n",
" 6 | \n",
" 300 | \n",
" -15 | \n",
" 67 | \n",
" 230 | \n",
" 237 | \n",
" 140 | \n",
" 6031 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Lodgepole Pine | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 6 | \n",
" 2606 | \n",
" 45 | \n",
" 7 | \n",
" 270 | \n",
" 5 | \n",
" 633 | \n",
" 222 | \n",
" 225 | \n",
" 138 | \n",
" 6256 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
" 7 | \n",
" 2605 | \n",
" 49 | \n",
" 4 | \n",
" 234 | \n",
" 7 | \n",
" 573 | \n",
" 222 | \n",
" 230 | \n",
" 144 | \n",
" 6228 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
" 8 | \n",
" 2617 | \n",
" 45 | \n",
" 9 | \n",
" 240 | \n",
" 56 | \n",
" 666 | \n",
" 223 | \n",
" 221 | \n",
" 133 | \n",
" 6244 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
" 9 | \n",
" 2612 | \n",
" 59 | \n",
" 10 | \n",
" 247 | \n",
" 11 | \n",
" 636 | \n",
" 228 | \n",
" 219 | \n",
" 124 | \n",
" 6230 | \n",
" rawah wilderness area | \n",
" Soil_Type_7745 | \n",
" Aspen | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" elevation aspect slope horizontal_distance_to_hydrology \\\n",
"0 2596 51 3 258 \n",
"1 2590 56 2 212 \n",
"2 2804 139 9 268 \n",
"3 2785 155 18 242 \n",
"4 2595 45 2 153 \n",
"5 2579 132 6 300 \n",
"6 2606 45 7 270 \n",
"7 2605 49 4 234 \n",
"8 2617 45 9 240 \n",
"9 2612 59 10 247 \n",
"\n",
" vertical_distance_to_hydrology horizontal_distance_to_roadways \\\n",
"0 0 510 \n",
"1 -6 390 \n",
"2 65 3180 \n",
"3 118 3090 \n",
"4 -1 391 \n",
"5 -15 67 \n",
"6 5 633 \n",
"7 7 573 \n",
"8 56 666 \n",
"9 11 636 \n",
"\n",
" hillshade_9am hillshade_noon hillshade_3pm \\\n",
"0 221 232 148 \n",
"1 220 235 151 \n",
"2 234 238 135 \n",
"3 238 238 122 \n",
"4 220 234 150 \n",
"5 230 237 140 \n",
"6 222 225 138 \n",
"7 222 230 144 \n",
"8 223 221 133 \n",
"9 228 219 124 \n",
"\n",
" horizontal_distance_to_fire_points wilderness_area soil_type \\\n",
"0 6279 rawah wilderness area Soil_Type_7745 \n",
"1 6225 rawah wilderness area Soil_Type_7745 \n",
"2 6121 rawah wilderness area Soil_Type_4744 \n",
"3 6211 rawah wilderness area Soil_Type_7746 \n",
"4 6172 rawah wilderness area Soil_Type_7745 \n",
"5 6031 rawah wilderness area Soil_Type_7745 \n",
"6 6256 rawah wilderness area Soil_Type_7745 \n",
"7 6228 rawah wilderness area Soil_Type_7745 \n",
"8 6244 rawah wilderness area Soil_Type_7745 \n",
"9 6230 rawah wilderness area Soil_Type_7745 \n",
"\n",
" cover_type wilderness_area_indexed soil_type_indexed \\\n",
"0 Aspen 0.0 0.0 \n",
"1 Aspen 0.0 0.0 \n",
"2 Lodgepole Pine 0.0 7.0 \n",
"3 Lodgepole Pine 0.0 6.0 \n",
"4 Aspen 0.0 0.0 \n",
"5 Lodgepole Pine 0.0 0.0 \n",
"6 Aspen 0.0 0.0 \n",
"7 Aspen 0.0 0.0 \n",
"8 Aspen 0.0 0.0 \n",
"9 Aspen 0.0 0.0 \n",
"\n",
" cover_type_indexed \n",
"0 5.0 \n",
"1 5.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 5.0 \n",
"5 0.0 \n",
"6 5.0 \n",
"7 5.0 \n",
"8 5.0 \n",
"9 5.0 "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"covertype_dataset_indexed_features.limit(10).toPandas()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we just have to combine our features into a single feature vector with a `VectorAssembler`:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pyspark.ml.feature import VectorAssembler\n",
"\n",
"# Numeric columns are passed to the assembler as they are.\n",
"continuous_columns = [\n",
"    \"elevation\",\n",
"    \"aspect\",\n",
"    \"slope\",\n",
"    \"horizontal_distance_to_hydrology\",\n",
"    \"vertical_distance_to_hydrology\",\n",
"    \"horizontal_distance_to_roadways\",\n",
"    \"hillshade_9am\",\n",
"    \"hillshade_noon\",\n",
"    \"hillshade_3pm\",\n",
"    \"horizontal_distance_to_fire_points\",\n",
"]\n",
"\n",
"# Categorical columns enter through their StringIndexer output columns.\n",
"categorical_columns = [\n",
"    \"wilderness_area_indexed\",\n",
"    \"soil_type_indexed\",\n",
"]\n",
"\n",
"# Keep this order: the same list is later passed to plot_tree as the feature names.\n",
"feature_columns = continuous_columns + categorical_columns\n",
"\n",
"feature_assembler = VectorAssembler(\n",
"    inputCols=feature_columns,\n",
"    outputCol=\"features\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we have our dataset prepared for ML:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Add the assembled \"features\" vector column to the indexed dataset.\n",
"covertype_dataset_prepared = feature_assembler.transform(\n",
"    covertype_dataset_indexed_features\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- elevation: long (nullable = true)\n",
" |-- aspect: long (nullable = true)\n",
" |-- slope: long (nullable = true)\n",
" |-- horizontal_distance_to_hydrology: long (nullable = true)\n",
" |-- vertical_distance_to_hydrology: long (nullable = true)\n",
" |-- horizontal_distance_to_roadways: long (nullable = true)\n",
" |-- hillshade_9am: long (nullable = true)\n",
" |-- hillshade_noon: long (nullable = true)\n",
" |-- hillshade_3pm: long (nullable = true)\n",
" |-- horizontal_distance_to_fire_points: long (nullable = true)\n",
" |-- wilderness_area: string (nullable = true)\n",
" |-- soil_type: string (nullable = true)\n",
" |-- cover_type: string (nullable = true)\n",
" |-- wilderness_area_indexed: double (nullable = false)\n",
" |-- soil_type_indexed: double (nullable = false)\n",
" |-- cover_type_indexed: double (nullable = false)\n",
" |-- features: vector (nullable = true)\n",
"\n"
]
}
],
"source": [
"covertype_dataset_prepared.printSchema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a simple `pyspark.ml.classification.DecisionTreeClassifier`:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pyspark.ml.classification import DecisionTreeClassifier\n",
"\n",
"# A shallow tree (maxDepth=3) keeps the resulting plot readable.\n",
"# maxBins must be at least the number of categories of any categorical feature.\n",
"dtree = DecisionTreeClassifier(featuresCol=\"features\",\n",
"                               labelCol=\"cover_type_indexed\",\n",
"                               maxDepth=3,\n",
"                               maxBins=50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We fit it, and we get our `DecisionTreeClassificationModel`:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassificationModel (uid=DecisionTreeClassifier_99684a674979) of depth 3 with 11 nodes"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train the tree on the prepared dataset (this is the expensive step of the notebook).\n",
"dtree_model = dtree.fit(covertype_dataset_prepared)\n",
"\n",
"dtree_model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.classification import DecisionTreeClassificationModel"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load a persisted DecisionTreeClassificationModel from the relative path 'tree_model'.\n",
"dtree_model = DecisionTreeClassificationModel.load(\n",
"    'tree_model'\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.toDebugString` attribute holds the decision rules for the tree as plain text, but it is not very user-friendly:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f84d275537cd) of depth 3 with 7 nodes\n",
" If (feature 0 <= 3050.5)\n",
" If (feature 0 <= 2540.5)\n",
" If (feature 10 in {0.0})\n",
" Predict: 0.0\n",
" Else (feature 10 not in {0.0})\n",
" Predict: 2.0\n",
" Else (feature 0 > 2540.5)\n",
" Predict: 0.0\n",
" Else (feature 0 > 3050.5)\n",
" Predict: 1.0\n",
"\n"
]
}
],
"source": [
"print(dtree_model.toDebugString)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Perhaps `spark_tree_plotting` may be helpful here ;)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "TypeError",
"evalue": "isinstance() arg 2 must be a type or tuple of types",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mfilled\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# With color!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mroundedCorners\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# Rounded corners in the nodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mroundLeaves\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m \u001b[0;31m# Leaves will be ellipses instead of rectangles\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m )\n",
"\u001b[0;32m~/wa/spark_wa/spark-tree-plotting/python/spark_tree_plotting.py\u001b[0m in \u001b[0;36mplot_tree\u001b[0;34m(DecisionTreeClassificationModel, featureNames, categoryNames, classNames, filled, roundedCorners, roundLeaves)\u001b[0m\n\u001b[1;32m 434\u001b[0m \u001b[0mfilled\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfilled\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[0mroundedCorners\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mroundedCorners\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 436\u001b[0;31m \u001b[0mroundLeaves\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mroundLeaves\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 437\u001b[0m )\n\u001b[1;32m 438\u001b[0m )\n",
"\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pydot/__init__.py\u001b[0m in \u001b[0;36mgraph_from_dot_data\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 216\u001b[0m \"\"\"\n\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 218\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdot_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_dot_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pydot/dot_parser.py\u001b[0m in \u001b[0;36mparse_dot_data\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 521\u001b[0m \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraphparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparseString\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 522\u001b[0;31m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokens\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pyparsing.py\u001b[0m in \u001b[0;36mparseString\u001b[0;34m(self, instring, parseAll)\u001b[0m\n\u001b[1;32m 1816\u001b[0m \u001b[0minstring\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpandtabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1817\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1818\u001b[0;31m \u001b[0mloc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parse\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1819\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mparseAll\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1820\u001b[0m \u001b[0mloc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreParse\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pyparsing.py\u001b[0m in \u001b[0;36m_parseNoCache\u001b[0;34m(self, instring, loc, doActions, callPreParse)\u001b[0m\n\u001b[1;32m 1593\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparseAction\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1594\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1595\u001b[0;31m \u001b[0mtokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0minstring\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokensStart\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretTokens\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1596\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mIndexError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mparse_action_exc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1597\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParseException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"exception raised in parse action\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pyparsing.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1217\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlimit\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1218\u001b[0m \u001b[0mfoundArity\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1219\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/anaconda3/envs/an37/lib/python3.7/site-packages/pydot/dot_parser.py\u001b[0m in \u001b[0;36mpush_top_graph_stmt\u001b[0;34m(str, loc, toks)\u001b[0m\n\u001b[1;32m 79\u001b[0m if( isinstance(element, (ParseResults, tuple, list)) and\n\u001b[1;32m 80\u001b[0m len(element) == 1 and isinstance(element[0], str) ):\n\u001b[0;32m---> 81\u001b[0;31m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 82\u001b[0m \u001b[0melement\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melement\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: isinstance() arg 2 must be a type or tuple of types"
]
}
],
"source": [
"import sys\n",
"sys.path.insert(0, '/Users/per0/wa/spark_wa/spark-tree-plotting/python')\n",
"from spark_tree_plotting import plot_tree\n",
"\n",
"tree_plot = plot_tree(dtree_model,\n",
" featureNames=feature_columns,\n",
" categoryNames={\"wilderness_area_indexed\":string_indexer_wilderness_model.labels,\n",
" \"soil_type_indexed\":string_indexer_soil_model.labels},\n",
" classNames=string_indexer_cover_model.labels,\n",
" filled=True, # With color!\n",
" roundedCorners=True, # Rounded corners in the nodes\n",
" roundLeaves=True # Leaves will be ellipses instead of rectangles\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Look up the JVM-side SparkMLTree class through the Py4J gateway and wrap the model's Java object.\n",
"# NOTE(review): presumably requires the spark-tree-plotting JAR on the Spark classpath -- confirm.\n",
"spark_ml_tree_class = sc._jvm.com.vfive.spark.ml.SparkMLTree\n",
"json_tree = spark_ml_tree_class(dtree_model._java_obj)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"featureIndex\":0,\n",
" \"gain\":0.08681394658400207,\n",
" \"impurity\":0.6230942824070332,\n",
" \"threshold\":3050.5,\n",
" \"nodeType\":\"internal\",\n",
" \"splitType\":\"continuous\",\n",
" \"prediction\":0.0,\n",
" \"leftChild\":{\n",
" \"featureIndex\":0,\n",
" \"gain\":0.08616165361635758,\n",
" \"impurity\":0.5539261911259398,\n",
" \"threshold\":2540.5,\n",
" \"nodeType\":\"internal\",\n",
" \"splitType\":\"continuous\",\n",
" \"prediction\":0.0,\n",
" \"leftChild\":{\n",
" \"featureIndex\":10,\n",
" \"gain\":0.04640621444482429,\n",
" \"impurity\":0.6171371727013576,\n",
" \"nodeType\":\"internal\",\n",
" \"splitType\":\"categorical\",\n",
" \"leftCategories\":[\n",
" 0.0\n",
" ],\n",
" \"rightCategories\":[\n",
" 1.0,\n",
" 2.0,\n",
" 3.0\n",
" ],\n",
" \"prediction\":2.0,\n",
" \"leftChild\":{\n",
" \"impurity\":0.18642232564845895,\n",
" \"nodeType\":\"leaf\",\n",
" \"prediction\":0.0\n",
" },\n",
" \"rightChild\":{\n",
" \"impurity\":0.5893401621499551,\n",
" \"nodeType\":\"leaf\",\n",
" \"prediction\":2.0\n",
" }\n",
" },\n",
" \"rightChild\":{\n",
" \"impurity\":0.4430125702798494,\n",
" \"nodeType\":\"leaf\",\n",
" \"prediction\":0.0\n",
" }\n",
" },\n",
" \"rightChild\":{\n",
" \"impurity\":0.5109863148417016,\n",
" \"nodeType\":\"leaf\",\n",
" \"prediction\":1.0\n",
" }\n",
"}\n"
]
}
],
"source": [
"print(json_tree.toJsonPlotFormat())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from IPython.display import Image\n",
"\n",
"# Render the result of plot_tree inline.\n",
"# NOTE(review): assumes tree_plot holds image data that Image can display -- confirm.\n",
"Image(\n",
"    tree_plot\n",
")"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}