{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Understanding the decision tree structure\n\nThe decision tree structure can be analysed to gain further insight into the\nrelation between the features and the target to predict. In this example, we\nshow how to retrieve:\n\n- the binary tree structure;\n- the depth of each node and whether or not it's a leaf;\n- the nodes that were reached by a sample using the ``decision_path`` method;\n- the leaf that was reached by a sample using the ``apply`` method;\n- the rules that were used to predict a sample;\n- the decision path shared by a group of samples.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport numpy as np\nfrom matplotlib import pyplot as plt\n\nfrom sklearn import tree\nfrom sklearn.datasets import load_iris\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train tree classifier\nFirst, we fit a :class:`~sklearn.tree.DecisionTreeClassifier` using the\n:func:`~sklearn.datasets.load_iris` dataset.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "iris = load_iris()\nX = iris.data\ny = iris.target\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\nclf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)\nclf.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tree structure\n\nThe decision tree classifier has an attribute called ``tree_`` which allows access\nto low-level attributes such as ``node_count``, the total number of nodes,\nand ``max_depth``, the maximal depth of the tree. The\n``tree_.compute_node_depths()`` method computes the depth of each node in the\ntree. 
`tree_` also stores the entire binary tree structure, represented as a\nnumber of parallel arrays. The i-th element of each array holds information\nabout the node ``i``. Node 0 is the tree's root. Some of the arrays only\napply to either leaves or split nodes. In this case, the values for nodes\nof the other type are arbitrary. For example, the arrays ``feature`` and\n``threshold`` only apply to split nodes. The values for leaf nodes in these\narrays are therefore arbitrary.\n\nAmong these arrays, we have:\n\n- ``children_left[i]``: id of the left child of node ``i`` or -1 if leaf node\n- ``children_right[i]``: id of the right child of node ``i`` or -1 if leaf node\n- ``feature[i]``: feature used for splitting node ``i``\n- ``threshold[i]``: threshold value at node ``i``\n- ``n_node_samples[i]``: the number of training samples reaching node ``i``\n- ``impurity[i]``: the impurity at node ``i``\n- ``weighted_n_node_samples[i]``: the weighted number of training samples\n  reaching node ``i``\n- ``value[i, j, k]``: the summary of the training samples that reached node ``i`` for\n  output ``j`` and class ``k`` (for a regression tree, the number of classes is set\n  to 1). See below for more information about ``value``.\n\nUsing the arrays, we can traverse the tree structure to compute various\nproperties. 
Below, we will compute the depth of each node and whether or not\nit is a leaf.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_nodes = clf.tree_.node_count\nchildren_left = clf.tree_.children_left\nchildren_right = clf.tree_.children_right\nfeature = clf.tree_.feature\nthreshold = clf.tree_.threshold\nvalues = clf.tree_.value\n\nnode_depth = np.zeros(shape=n_nodes, dtype=np.int64)\nis_leaves = np.zeros(shape=n_nodes, dtype=bool)\nstack = [(0, 0)]  # start with the root node id (0) and its depth (0)\nwhile len(stack) > 0:\n    # `pop` ensures each node is only visited once\n    node_id, depth = stack.pop()\n    node_depth[node_id] = depth\n\n    # If the left and right children of a node are not the same, we have a\n    # split node (leaves store -1 for both children)\n    is_split_node = children_left[node_id] != children_right[node_id]\n    # If a split node, append left and right children and depth to `stack`\n    # so we can loop through them\n    if is_split_node:\n        stack.append((children_left[node_id], depth + 1))\n        stack.append((children_right[node_id], depth + 1))\n    else:\n        is_leaves[node_id] = True\n\nprint(\n    \"The binary tree structure has {n} nodes and has \"\n    \"the following tree structure:\\n\".format(n=n_nodes)\n)\nfor i in range(n_nodes):\n    if is_leaves[i]:\n        print(\n            \"{space}node={node} is a leaf node with value={value}.\".format(\n                space=node_depth[i] * \"\\t\", node=i, value=np.around(values[i], 3)\n            )\n        )\n    else:\n        print(\n            \"{space}node={node} is a split node with value={value}: \"\n            \"go to node {left} if X[:, {feature}] <= {threshold} \"\n            \"else to node {right}.\".format(\n                space=node_depth[i] * \"\\t\",\n                node=i,\n                left=children_left[i],\n                feature=feature[i],\n                threshold=threshold[i],\n                right=children_right[i],\n                value=np.around(values[i], 3),\n            )\n        )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What is the values array used here?\nThe `tree_.value` array is a 3D array of shape\n[``n_nodes``, ``n_outputs``, ``n_classes``], consistent with the ``value[i, j, k]``\nindexing described above. For each node, it provides the proportion of samples\nof each class for each output.\nEach node has a ``value`` array holding, for each output and class, the proportion\nof weighted samples of that class among all weighted samples reaching this node.\n\nOne can convert these proportions to the absolute weighted number of samples\nreaching a node by multiplying them by `tree_.weighted_n_node_samples[node_idx]`\nfor the given node. Note sample weights are not used in this example, so the weighted\nnumber of samples is the number of samples reaching the node because each sample\nhas a weight of 1 by default.\n\nFor example, in the above tree built on the iris dataset, the root node has\n``value = [0.33, 0.304, 0.366]`` indicating there are 33% of class 0 samples,\n30.4% of class 1 samples, and 36.6% of class 2 samples at the root node. One can\nconvert this to the absolute number of samples by multiplying by the number of\nsamples reaching the root node, which is `tree_.weighted_n_node_samples[0]`.\nThen the root node has ``value = [37, 34, 41]``, indicating there are 37 samples\nof class 0, 34 samples of class 1, and 41 samples of class 2 at the root node.\n\nAs we traverse the tree, the samples are split and, as a result, the ``value``\narray changes from node to node. The left child of the root node has ``value = [1., 0, 0]``\n(or ``value = [37, 0, 0]`` when converted to the absolute number of samples)\nbecause all 37 samples in the left child node are from class 0.\n\nNote: In this example, `n_outputs=1`, but the tree classifier can also handle\nmulti-output problems. 
The `value` array at each node would then hold one\nrow per output.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can compare the above output to the plot of the decision tree.\nHere, we show the proportions of samples of each class that reach each\nnode, which correspond to the actual elements of the `tree_.value` array.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "tree.plot_tree(clf, proportion=True)\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision path\n\nWe can also retrieve the decision path of samples of interest. The\n``decision_path`` method outputs an indicator matrix that allows us to\nretrieve the nodes the samples of interest traverse through. A nonzero\nelement in the indicator matrix at position ``(i, j)`` indicates that\nthe sample ``i`` goes through the node ``j``. Equivalently, for one sample ``i``, the\npositions of the nonzero elements in row ``i`` of the indicator matrix\ndesignate the ids of the nodes that sample goes through.\n\nThe leaf ids reached by samples of interest can be obtained with the\n``apply`` method. This returns an array of the node ids of the leaves\nreached by each sample of interest. Using the leaf ids and the\n``decision_path`` we can obtain the splitting conditions that were used to\npredict a sample or a group of samples. 
First, let's do it for one sample.\nNote that ``node_indicator`` is a sparse matrix, so the node ids of a row are read\nfrom its ``indices`` and ``indptr`` attributes.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "node_indicator = clf.decision_path(X_test)\nleaf_id = clf.apply(X_test)\n\nsample_id = 0\n# obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`\nnode_index = node_indicator.indices[\n    node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1]\n]\n\nprint(\"Rules used to predict sample {id}:\\n\".format(id=sample_id))\nfor node_id in node_index:\n    # continue to the next node if it is a leaf node\n    if leaf_id[sample_id] == node_id:\n        continue\n\n    # check if value of the split feature for `sample_id` is below threshold\n    if X_test[sample_id, feature[node_id]] <= threshold[node_id]:\n        threshold_sign = \"<=\"\n    else:\n        threshold_sign = \">\"\n\n    print(\n        \"decision node {node} : (X_test[{sample}, {feature}] = {value}) \"\n        \"{inequality} {threshold}\".format(\n            node=node_id,\n            sample=sample_id,\n            feature=feature[node_id],\n            value=X_test[sample_id, feature[node_id]],\n            inequality=threshold_sign,\n            threshold=threshold[node_id],\n        )\n    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a group of samples, we can determine the common nodes the samples go\nthrough.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sample_ids = [0, 1]\n# boolean array indicating the nodes both samples go through\ncommon_nodes = node_indicator.toarray()[sample_ids].sum(axis=0) == len(sample_ids)\n# obtain node ids using position in array\ncommon_node_id = np.arange(n_nodes)[common_nodes]\n\nprint(\n    \"\\nThe following samples {samples} share the node(s) {nodes} in the tree.\".format(\n        samples=sample_ids, nodes=common_node_id\n    )\n)\nprint(\"This is {prop}% of all nodes.\".format(prop=100 * len(common_node_id) / n_nodes))" ] } ], "metadata": { "kernelspec": { "display_name": 
"Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }