{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# %load /Users/facai/Study/book_notes/preconfig.py\n", "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from IPython.display import SVG" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "决策树在 sklearn 中的实现简介\n", "============================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### 0. 预前\n", "本文简单分析 [scikit-learn/scikit-learn](https://github.com/scikit-learn/scikit-learn) 中决策树涉及的代码模块关系。\n", "\n", "分析的代码版本信息是:\n", "```shell\n", "~/W/s/sklearn ❯❯❯ git log -n 1 study/analyses_decision_tree\n", "commit d161bfaa1a42da75f4940464f7f1c524ef53484f\n", "Author: John B Nelson \n", "Date: Thu May 26 18:36:37 2016 -0400\n", "\n", " Add missing double quote (#6831)\n", "```\n", "\n", "本文假设读者已经了解决策树的其本概念,阅读 [sklearn - Decision Trees](http://scikit-learn.org/stable/modules/tree.html) 有助于快速了解。 " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. 
Overall structure\n", "\n", "The decision tree code lives under the `scikit-learn/sklearn/tree` directory; the files are summarized below:\n", "\n", "```\n", "tree\n", "+-- __init__.py\n", "+-- setup.py\n", "+-- tree.py main module\n", "+-- export.py exports the tree model\n", "+-- _tree.* classes that assemble the tree\n", "+-- _splitter.* splitting methods\n", "+-- _criterion.* impurity criteria\n", "+-- _utils.* helper data structures: stack and priority heap\n", "+-- tests/\n", " +-- __init__.py\n", " +-- test_tree.py\n", " +-- test_export.py\n", "```\n", "\n", "The rough relationships between the classes are as follows:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "DecisionTreeClassifier+predict_proba()+predict_log_proba()BaseDecisionTree+fit()+predict()DecisionTreeRegressorExtraTreeClassifierExtraTreeRegressorTreeBuilder+splitter+min_samples_split+min_samples_leaf+min_weight_leaf+max_depth+build()+_check_input()Splitter+node_impurity()+node_reset()+node_split()+node_value()Criterion+proxy_impurity_improvement()+impurity_improvement()+1..1+1..1+1..1" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "SVG(\"./res/uml/Model__tree_0.svg\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`tree.py` defines the `BaseDecisionTree` base class, which implements the complete classification and regression functionality; the derived subclasses mainly wrap the initialization parameters. The two families of subclasses differ in how they split: the `DecisionTree*` classes sweep over features and values to find the best split point, while the `ExtraTree*` classes draw features and values at random to find a split point.\n", "\n", "The base class's training method `fit` proceeds as follows:\n", "\n", "1. Check the parameters.\n", "2. Set up the criterion.\n", "3. Create the splitter: instantiate the class matching whether the data is a sparse matrix.\n", "4. Create the tree: choose depth-first or best-first building depending on the leaf-node limit (`max_leaf_nodes`).\n", "5. 
Build the tree: invoke the builder to grow the decision tree.\n", "\n", "The code is shown below, with the details folded away:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", " 72 class BaseDecisionTree(six.with_metaclass(ABCMeta, BaseEstimator,\n", " 73 _LearntSelectorMixin)):\n", " 74 \"\"\"Base class for decision trees.\n", " 75 #+-- 3 lines: Warning: This class should not be used directly.-------------------\n", " 78 \"\"\"\n", " 79\n", " 80 @abstractmethod\n", " 81 def __init__(self,\n", " 82 #+-- 30 lines: criterion,---------------------------------------------------------\n", " 112\n", " 113 def fit(self, X, y, sample_weight=None, check_input=True,\n", " 114 X_idx_sorted=None):\n", " 115 \"\"\"Build a decision tree from the training set (X, y).\n", " 116 #+-- 34 lines: Parameters---------------------------------------------------------\n", " 150 \"\"\"\n", " 151\n", " 152 #+--180 lines: random_state = check_random_state(self.random_state)---------------\n", " 332\n", " 333 # Build tree\n", " 334 criterion = self.criterion\n", " 335 #+-- 6 lines: if not isinstance(criterion, Criterion):---------------------------\n", " 341\n", " 342 SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS\n", " 343 #+-- 9 lines: splitter = self.splitter-------------------------------------------\n", " 352\n", " 353 self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)\n", " 354\n", " 355 # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise\n", " 356 if max_leaf_nodes < 0:\n", " 357 builder = DepthFirstTreeBuilder(splitter, min_samples_split,\n", " 358 min_samples_leaf,\n", " 359 min_weight_leaf,\n", " 360 max_depth)\n", " 361 else:\n", " 362 builder = BestFirstTreeBuilder(splitter, min_samples_split,\n", " 363 min_samples_leaf,\n", " 364 min_weight_leaf,\n", " 365 max_depth,\n", " 366 max_leaf_nodes)\n", " 367\n", " 368 builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)\n", " 369\n", " 370 #+-- 3 lines: if self.n_outputs_ == 1:-------------------------------------------\n", " 373\n", " 
374 return self\n", "```" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "The prediction method `predict` is very simple: it calls `tree_.predict()` to obtain the predicted values. For a classification problem it outputs the class with the largest predicted value; for a regression problem it outputs the value directly." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", " 398 def predict(self, X, check_input=True):\n", " 399 \"\"\"Predict class or regression value for X.\n", " 400 +-- 20 lines: For a classification model, the predicted class for each sample in X\n", " 420 \"\"\"\n", " 421\n", " 422 X = self._validate_X_predict(X, check_input)\n", " 423 proba = self.tree_.predict(X)\n", " 424 n_samples = X.shape[0]\n", " 425\n", " 426 # Classification\n", " 427 if isinstance(self, ClassifierMixin):\n", " 428 if self.n_outputs_ == 1:\n", " 429 return self.classes_.take(np.argmax(proba, axis=1), axis=0) \n", " 430\n", " 431 +--- 9 lines: else:--------------------------------------------------------------\n", " 440\n", " 441 # Regression\n", " 442 else:\n", " 443 if self.n_outputs_ == 1:\n", " 444 return proba[:, 0]\n", " 445 +--- 3 lines: else:--------------------------------------------------------------\n", " ```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "sklearn's decision tree is the CART (Classification and Regression Trees) algorithm: classification problems are cast as regression on predicted probabilities, so both kinds of problems are handled in the same way; the main difference lies in the criterion." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2 The modules in brief" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### 2.0 Criteria\n", "`_criterion.*` are the files for the impurity criteria, implemented in Cython; the *.pxd and *.pyx files correspond to C's *.h and *.c files, respectively.\n", "\n", "The class diagram is below:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ 
"Criterion+proxy_impurity_improvement()+impurity_improvement()ClassificationCriterion+init()+reset()+reverse_reset()+update()+node_value()RegressionCriterion+init()+reset()+reverse_reset()+update()+node_value()Entropy+node_impurity()+children_impurity()Gini+node_impurity()+children_impurity()MSE+node_impurity()+children_impurity()+proxy_impurity_improvement()FriedmanMSE+proxy_impurity_improvement()+impurity_improvement()" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "SVG(\"./res/uml/Model___criterion_1.svg\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "+ 对于分类问题,sklearn 提供了 Gini 和 Entropy 两种评价函数;\n", " - 默认会用 Gini 。\n", " - [Decision Trees: “Gini” vs. “Entropy” criteria](https://www.garysieling.com/blog/sklearn-gini-vs-entropy-criteria)\n", "\n", "+ 对于回归问题,则提供了 MSE(均方差)和 FriedmanMSE。\n", " - 默认会用 MSE 。\n", " - FriedmanMSE 用于 gradient boosting。\n", " \n", "在实际使用中,我们应该都测试下同的评价函数。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### 2.1 分割方法\n", "`_splitter.*` 是分割方法相关的文件。\n", "\n", "下面是类的关系图:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "Splitter+node_impurity()+node_reset()+node_split()+node_value()BaseDenseSplitter+init()BestSplitter+node_split()RandomSplitter+node_split()BaseSparseSplitter+init()+extract_nnz()BestSparseSplitter+node_split()RandomSparseSplitter+node_split()Best:遍历一个特征的值以确定最佳阈值;Random:在一个特征的最大值和最小值之间随机抽样一个值作为最佳阈值。«dataType»SplitRecord+feature+pos+threshold+improvement+impurity_left+impurity_right+1..*" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "SVG(\"./res/uml/Model___splitter_2.svg\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Splitter` 基类依数据储存方式(实阵或稀疏阵)衍生为 `BaseDenseSplitte` 和 `BaseSparseSplitter`。在这之下根据阈值的寻优方法再细分两类:`Best*Splitter` 会遍历特征的可能值,而 `Random*Splitter` 则是随机抽取。" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "##### 2.2 树的组建方法\n", "`_tree.*` 是树组建方法相关的文件。\n", "\n", "下面是类的关系图:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "«dataType»Node+left_child+right_child+feature+threshold+impurity+n_node_samples+weighted_n_node_samplesTree+_add_note()+_resize()+predict()+apply()+decision_path()+compute_feature_importances()TreeBuilder+splitter+min_samples_split+min_samples_leaf+min_weight_leaf+max_depth+build()+_check_input()+1..1+1..*DepthFirstTreeBuilder+build()_utils.Stack+push()+pop()+1..1BestFirstTreeBuilder+build()_utils.PriorityHeap+push()+pop()+1..1" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "SVG(\"./res/uml/Model___tree_3.svg\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "sklearn 提供了两种树的组建方法:一种是用栈实现的深度优先方法,它会先左后右地生成整颗决策树;另一种是用最大堆实现的最优优先方法,它每次在纯净度提升最大的节点进行分割生长。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3 结语\n", "本文简单介绍了 sklearn 中决策树的实现框架,后面会对各子模块作进一步的详述。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 0 }