{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision tree for classification in plain Python\n", "\n", "\n", "A decision tree is a **supervised** machine learning model that can be used both for **classification** and **regression**. At its core, a decision tree uses a tree structure to predict an output value for a given input example. In the tree, each path from the root node to a leaf node represents a *decision path* that ends in a predicted value. \n", "\n", "A simple example might look as follows:\n", "![caption](figures/decision_tree.png)\n", "\n", "Decision trees have many advantages. For example, they are easy to understand and their decisions are easy to interpret. Also, they don't require a lot of data preparation. A more extensive list of their advantages and disadvantages can be found [here](http://scikit-learn.org/stable/modules/tree.html).\n", "\n", "### CART training algorithm \n", "In order to train a decision tree, various algorithms can be used. In this notebook we will focus on the *CART* algorithm (Classification and Regression Trees) for *classification*. The CART algorithm builds a *binary tree* in which every non-leaf node has exactly two children (corresponding to a yes/no answer). \n", "\n", "Given a set of training examples and their labels, the algorithm repeatedly splits the training examples $D$ into two subsets $D_{left}, D_{right}$ using some feature $f$ and feature threshold $t_f$ such that samples with the same label are grouped together. At each node, the algorithm selects the split $\\theta = (f, t_f)$ that produces the purest subsets, weighted by their size. 
Purity/impurity is measured using the *Gini impurity*.\n", "\n", "So at each step, the algorithm selects the parameters $\\theta$ that minimize the following cost function:\n", "\n", "\\begin{equation}\n", "J(D, \\theta) = \\frac{n_{left}}{n_{total}} G_{left} + \\frac{n_{right}}{n_{total}} G_{right}\n", "\\end{equation}\n", "\n", "- $D$: remaining training examples \n", "- $n_{total}$ : number of remaining training examples\n", "- $\\theta = (f, t_f)$: feature and feature threshold\n", "- $n_{left}/n_{right}$: number of samples in the left/right subset\n", "- $G_{left}/G_{right}$: Gini impurity of the left/right subset\n", "\n", "This step is repeated recursively until the *maximum allowable depth* is reached or the current number of samples $n_{total}$ drops below some minimum number. The original equations can be found [here](http://scikit-learn.org/stable/modules/tree.html).\n", "\n", "After building the tree, new examples can be classified by navigating through the tree, testing at each node the corresponding feature until a leaf node/prediction is reached.\n", "\n", "\n", "### Gini Impurity\n", "\n", "Given $K$ different classification values $k \\in \\{1, ..., K\\}$ the Gini impurity of node $m$ is computed as follows:\n", "\n", "\\begin{equation}\n", "G_m = 1 - \\sum_{k=1}^{K} (p_{m,k})^2\n", "\\end{equation}\n", "\n", "where $p_{m, k}$ is the fraction of training examples with class $k$ among all training examples in node $m$.\n", "\n", "The Gini impurity can be used to evaluate how good a potential split is. A split divides a given set of training examples into two groups. Gini measures how \"mixed\" the resulting groups are. A perfect separation (i.e. each group contains only samples of the same class) corresponds to a Gini impurity of 0. 
If each group contains equally many samples of every class, the Gini impurity reaches its maximum of $1 - 1/K$, which is 0.5 for two classes.\n", "\n", "### Caveats\n", "\n", "Without regularization, decision trees are likely to overfit the training examples. This can be prevented using techniques like *pruning* or by providing a maximum allowed tree depth and/or a minimum number of samples required to split a node further." ] }, { "cell_type": "code", "execution_count": 110, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:18.442511Z", "start_time": "2018-04-09T12:50:18.426963Z" } }, "outputs": [], "source": [ "from sklearn.datasets import load_iris\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.model_selection import train_test_split\n", "np.random.seed(123)\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "The iris dataset comprises 150 examples of 3 different species of iris flowers (Setosa, Versicolour, and Virginica). Each example is described by four attributes: sepal length (cm), sepal width (cm), petal length (cm), and petal width (cm)."
] }, { "cell_type": "code", "execution_count": 111, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:18.959842Z", "start_time": "2018-04-09T12:50:18.951915Z" } }, "outputs": [], "source": [ "iris = load_iris()\n", "\n", "X, y = iris.data, iris.target" ] }, { "cell_type": "code", "execution_count": 112, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:19.202874Z", "start_time": "2018-04-09T12:50:19.188064Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape X_train: (112, 4)\n", "Shape y_train: (112,)\n", "Shape X_test: (38, 4)\n", "Shape y_test: (38,)\n" ] } ], "source": [ "# Split the data into a training and test set\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", "\n", "print(f'Shape X_train: {X_train.shape}')\n", "print(f'Shape y_train: {y_train.shape}')\n", "print(f'Shape X_test: {X_test.shape}')\n", "print(f'Shape y_test: {y_test.shape}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision tree class\n", "\n", "Parts of this code were inspired by [this tutorial](https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/)" ] }, { "cell_type": "code", "execution_count": 113, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:20.706089Z", "start_time": "2018-04-09T12:50:20.093228Z" } }, "outputs": [], "source": [ "class DecisionTree:\n", "    \"\"\"\n", "    Decision tree for classification\n", "    \"\"\"\n", "\n", "    def __init__(self):\n", "        self.root_dict = None\n", "        self.tree_dict = None\n", "\n", "    def split_dataset(self, X, y, feature_idx, threshold):\n", "        \"\"\"\n", "        Splits dataset X into two subsets, according to a given feature\n", "        and feature threshold.\n", "\n", "        Args:\n", "            X: 2D numpy array with data samples\n", "            y: 1D numpy array with labels\n", "            feature_idx: int, index of feature used for splitting the data\n", "            threshold: float, threshold used for splitting the data\n", "\n", "        Returns:\n", "            splits: dict containing 
the left and right subsets\n", "            and their labels\n", "        \"\"\"\n", "\n", "        left_idx = np.where(X[:, feature_idx] < threshold)\n", "        right_idx = np.where(X[:, feature_idx] >= threshold)\n", "\n", "        left_subset = X[left_idx]\n", "        y_left = y[left_idx]\n", "\n", "        right_subset = X[right_idx]\n", "        y_right = y[right_idx]\n", "\n", "        splits = {\n", "            'left': left_subset,\n", "            'y_left': y_left,\n", "            'right': right_subset,\n", "            'y_right': y_right,\n", "        }\n", "\n", "        return splits\n", "\n", "    def gini_impurity(self, y_left, y_right, n_left, n_right):\n", "        \"\"\"\n", "        Computes the Gini impurity of both subsets of a split.\n", "\n", "        Args:\n", "            y_left, y_right: target values of samples in left/right subset\n", "            n_left, n_right: number of samples in left/right subset\n", "\n", "        Returns:\n", "            gini_left: float, Gini impurity of left subset\n", "            gini_right: float, Gini impurity of right subset\n", "        \"\"\"\n", "\n", "        score_left, score_right = 0, 0\n", "        gini_left, gini_right = 0, 0\n", "\n", "        if n_left != 0:\n", "            for c in range(self.n_classes):\n", "                # For each class c, compute fraction of samples with class c\n", "                # (assumes the labels are the integers 0, ..., n_classes - 1)\n", "                p_left = len(np.where(y_left == c)[0]) / n_left\n", "                score_left += p_left * p_left\n", "            gini_left = 1 - score_left\n", "\n", "        if n_right != 0:\n", "            for c in range(self.n_classes):\n", "                p_right = len(np.where(y_right == c)[0]) / n_right\n", "                score_right += p_right * p_right\n", "            gini_right = 1 - score_right\n", "\n", "        return gini_left, gini_right\n", "\n", "    def get_cost(self, splits):\n", "        \"\"\"\n", "        Computes cost of a split given the Gini impurity of\n", "        the left and right subset and the sizes of the subsets\n", "\n", "        Args:\n", "            splits: dict, containing params of current split\n", "\n", "        Returns:\n", "            cost: float, size-weighted Gini impurity of the split\n", "        \"\"\"\n", "        y_left = splits['y_left']\n", "        y_right = splits['y_right']\n", "\n", "        n_left = len(y_left)\n", "        n_right = len(y_right)\n", "        n_total = n_left + n_right\n", "\n", "        gini_left, gini_right = self.gini_impurity(y_left, 
y_right, n_left, n_right)\n", "        cost = (n_left / n_total) * gini_left + (n_right / n_total) * gini_right\n", "\n", "        return cost\n", "\n", "    def find_best_split(self, X, y):\n", "        \"\"\"\n", "        Finds the best feature and feature threshold to split dataset X into\n", "        two groups. Checks every value of every attribute as a candidate\n", "        split.\n", "\n", "        Args:\n", "            X: 2D numpy array with data samples\n", "            y: 1D numpy array with labels\n", "\n", "        Returns:\n", "            best_split_params: dict, containing parameters of the best split\n", "        \"\"\"\n", "\n", "        n_samples, n_features = X.shape\n", "\n", "        best_feature_idx, best_threshold, best_cost, best_splits = np.inf, np.inf, np.inf, None\n", "\n", "        for feature_idx in range(n_features):\n", "            for i in range(n_samples):\n", "                current_sample = X[i]\n", "                threshold = current_sample[feature_idx]\n", "                splits = self.split_dataset(X, y, feature_idx, threshold)\n", "                cost = self.get_cost(splits)\n", "\n", "                if cost < best_cost:\n", "                    best_feature_idx = feature_idx\n", "                    best_threshold = threshold\n", "                    best_cost = cost\n", "                    best_splits = splits\n", "\n", "        best_split_params = {\n", "            'feature_idx': best_feature_idx,\n", "            'threshold': best_threshold,\n", "            'cost': best_cost,\n", "            'left': best_splits['left'],\n", "            'y_left': best_splits['y_left'],\n", "            'right': best_splits['right'],\n", "            'y_right': best_splits['y_right'],\n", "        }\n", "\n", "        return best_split_params\n", "\n", "\n", "    def build_tree(self, node_dict, depth, max_depth, min_samples):\n", "        \"\"\"\n", "        Builds the decision tree in a recursive fashion.\n", "\n", "        Args:\n", "            node_dict: dict, representing the current node\n", "            depth: int, depth of current node in the tree\n", "            max_depth: int, maximum allowed tree depth\n", "            min_samples: int, minimum number of samples needed to split a node further\n", "\n", "        Returns:\n", "            node_dict: dict, representing the full subtree originating from current node\n", "        \"\"\"\n", "        left_samples = node_dict['left']\n", "        right_samples = node_dict['right']\n", "        y_left_samples = node_dict['y_left']\n", "        y_right_samples = node_dict['y_right']\n", "\n", "        # One side is empty: both children predict the majority class of the node\n", "        if len(y_left_samples) == 0 or len(y_right_samples) == 0:\n", "            node_dict[\"left_child\"] = node_dict[\"right_child\"] = self.create_terminal_node(np.append(y_left_samples, y_right_samples))\n", "            return node_dict\n", "\n", "        # Maximum depth reached: turn both children into leaves\n", "        if depth >= max_depth:\n", "            node_dict[\"left_child\"] = self.create_terminal_node(y_left_samples)\n", "            node_dict[\"right_child\"] = self.create_terminal_node(y_right_samples)\n", "            return node_dict\n", "\n", "        if len(right_samples) < min_samples:\n", "            node_dict[\"right_child\"] = self.create_terminal_node(y_right_samples)\n", "        else:\n", "            node_dict[\"right_child\"] = self.find_best_split(right_samples, y_right_samples)\n", "            self.build_tree(node_dict[\"right_child\"], depth+1, max_depth, min_samples)\n", "\n", "        if len(left_samples) < min_samples:\n", "            node_dict[\"left_child\"] = self.create_terminal_node(y_left_samples)\n", "        else:\n", "            node_dict[\"left_child\"] = self.find_best_split(left_samples, y_left_samples)\n", "            self.build_tree(node_dict[\"left_child\"], depth+1, max_depth, min_samples)\n", "\n", "        return node_dict\n", "\n", "    def create_terminal_node(self, y):\n", "        \"\"\"\n", "        Creates a terminal node.\n", "        Given a set of labels the most common label is computed and\n", "        set as the classification value of the node.\n", "\n", "        Args:\n", "            y: 1D numpy array with labels\n", "        Returns:\n", "            classification: int, predicted class\n", "        \"\"\"\n", "        classification = max(set(y), key=list(y).count)\n", "        return classification\n", "\n", "    def train(self, X, y, max_depth, min_samples):\n", "        \"\"\"\n", "        Fits decision tree on a given dataset.\n", "\n", "        Args:\n", "            X: 2D numpy array with data samples\n", "            y: 1D numpy array with labels\n", "            max_depth: int, maximum allowed tree depth\n", "            min_samples: int, minimum number of samples needed to split a node further\n", "        \"\"\"\n", "        self.n_classes = 
len(set(y))\n", "        self.root_dict = self.find_best_split(X, y)\n", "        self.tree_dict = self.build_tree(self.root_dict, 1, max_depth, min_samples)\n", "\n", "    def predict(self, X, node):\n", "        \"\"\"\n", "        Predicts the class for a given input example X.\n", "\n", "        Args:\n", "            X: 1D numpy array, input example\n", "            node: dict, representing trained decision tree\n", "\n", "        Returns:\n", "            prediction: int, predicted class\n", "        \"\"\"\n", "        feature_idx = node['feature_idx']\n", "        threshold = node['threshold']\n", "\n", "        if X[feature_idx] < threshold:\n", "            if isinstance(node['left_child'], (int, np.integer)):\n", "                return node['left_child']\n", "            else:\n", "                prediction = self.predict(X, node['left_child'])\n", "        elif X[feature_idx] >= threshold:\n", "            if isinstance(node['right_child'], (int, np.integer)):\n", "                return node['right_child']\n", "            else:\n", "                prediction = self.predict(X, node['right_child'])\n", "\n", "        return prediction" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:19:30.968801Z", "start_time": "2018-04-09T12:19:30.963691Z" } }, "source": [ "## Initializing and training the decision tree" ] }, { "cell_type": "code", "execution_count": 114, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:21.291025Z", "start_time": "2018-04-09T12:50:21.235025Z" } }, "outputs": [], "source": [ "tree = DecisionTree()\n", "tree.train(X_train, y_train, max_depth=2, min_samples=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Printing the decision tree structure" ] }, { "cell_type": "code", "execution_count": 115, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:24.127183Z", "start_time": "2018-04-09T12:50:24.115505Z" } }, "outputs": [], "source": [ "def print_tree(node, depth=0):\n", "    if isinstance(node, (int, np.integer)):\n", "        print(f\"{depth * ' '}predicted class: {iris.target_names[node]}\")\n", "    else:\n", "        print(f\"{depth * ' '}{iris.feature_names[node['feature_idx']]} < 
{node['threshold']}, \"\n", "              f\"cost of split: {round(node['cost'], 3)}\")\n", "        print_tree(node[\"left_child\"], depth+1)\n", "        print_tree(node[\"right_child\"], depth+1)" ] }, { "cell_type": "code", "execution_count": 116, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:24.818384Z", "start_time": "2018-04-09T12:50:24.809082Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "petal length (cm) < 3.0, cost of split: 0.346\n", " sepal length (cm) < 5.4, cost of split: 0.0\n", "  predicted class: setosa\n", "  predicted class: setosa\n", " petal width (cm) < 1.8, cost of split: 0.097\n", "  predicted class: versicolor\n", "  predicted class: virginica\n" ] } ], "source": [ "print_tree(tree.tree_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testing the decision tree" ] }, { "cell_type": "code", "execution_count": 117, "metadata": { "ExecuteTime": { "end_time": "2018-04-09T12:50:28.307583Z", "start_time": "2018-04-09T12:50:28.284228Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on test set: 0.9473684210526315\n" ] } ], "source": [ "all_predictions = []\n", "for i in range(X_test.shape[0]):\n", "    result = tree.predict(X_test[i], tree.tree_dict)\n", "    all_predictions.append(y_test[i] == result)\n", "\n", "print(f\"Accuracy on test set: {sum(all_predictions) / len(all_predictions)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": 
"Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }