{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# %load /Users/facai/Study/book_notes/preconfig.py\n",
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set(color_codes=True)\n",
"sns.set(font='SimHei')\n",
"plt.rcParams['axes.grid'] = False\n",
"\n",
"from IPython.display import SVG\n",
"\n",
"def show_image(filename, figsize=None):\n",
"    if figsize:\n",
"        plt.figure(figsize=figsize)\n",
"\n",
"    plt.imshow(plt.imread(filename))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Close Look at the Tree-Building Module _tree.*\n",
"================================================"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 0. Outline\n",
"\n",
"This module contains two kinds of classes: `Tree`, which implements the binary tree, and the `TreeBuilder` classes, which construct the whole tree. We focus on the builder classes, and briefly cover a few `Tree` methods at the end."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SVG(\"./res/uml/Model___tree_3.svg\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Builder Classes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.0 TreeBuilder\n",
"TreeBuilder defines the interface method `build` and provides a concrete input-checking method `_check_input`; there is not much more to it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.1 DepthFirstTreeBuilder\n",
"DepthFirstTreeBuilder grows the whole decision tree in a pre-order-like traversal, using a stack as its auxiliary data structure.\n",
"\n",
"The main flow is:\n",
"\n",
"1. Pop a node off the stack.\n",
"2. Compute its split point:\n",
"    + if the leaf conditions are met, this node is finished;\n",
"    + otherwise, push the right child onto the stack first, then the left child.\n",
"3. When the stack is empty, the tree is complete."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The core of the code is:\n",
"\n",
"```Python\n",
"with nogil:\n",
"    # push root node onto stack\n",
"    rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)\n",
"    # ... 4 lines: if rc == -1: ...\n",
"\n",
"    while not stack.is_empty():\n",
"        stack.pop(&stack_record)\n",
"        # ... 9 lines: start = stack_record.start ...\n",
"        n_node_samples = end - start\n",
"        splitter.node_reset(start, end, &weighted_n_node_samples)\n",
"\n",
"        is_leaf = ((depth >= max_depth) or\n",
"                   (n_node_samples < min_samples_split) or\n",
"                   (n_node_samples < 2 * min_samples_leaf) or\n",
"                   (weighted_n_node_samples < min_weight_leaf))\n",
"\n",
"        if first:\n",
"            impurity = splitter.node_impurity()\n",
"            first = 0\n",
"\n",
"        is_leaf = is_leaf or (impurity <= MIN_IMPURITY_SPLIT)\n",
"\n",
"        if not is_leaf:\n",
"            splitter.node_split(impurity, &split, &n_constant_features)\n",
"            is_leaf = is_leaf or (split.pos >= end)\n",
"\n",
"        node_id = tree._add_node(parent, is_left, is_leaf, split.feature,\n",
"                                 split.threshold, impurity, n_node_samples,\n",
"                                 weighted_n_node_samples)\n",
"\n",
"        # ... 4 lines: if node_id == (-1): ...\n",
"        # Store value for all nodes, to facilitate tree/model\n",
"        # inspection and interpretation\n",
"        splitter.node_value(tree.value + node_id * tree.value_stride)\n",
"\n",
"        if not is_leaf:\n",
"            # Push right child on stack\n",
"            rc = stack.push(split.pos, end, depth + 1, node_id, 0,\n",
"                            split.impurity_right, n_constant_features)\n",
"            # ... 3 lines: if rc == -1: ...\n",
"\n",
"            # Push left child on stack\n",
"            rc = stack.push(start, split.pos, depth + 1, node_id, 1,\n",
"                            split.impurity_left, n_constant_features)\n",
"            # ... 2 lines: if rc == -1: ...\n",
"```"
]
},
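{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same control flow can be sketched in runnable pure Python. This is a simplified, hypothetical model, not sklearn's API: `find_best_split` just cuts a sample range at its midpoint, and nodes are dicts in a flat list.\n",
"\n",
"```Python\n",
"# Hypothetical sketch of the stack-driven (pre-order) build above.\n",
"def find_best_split(start, end):\n",
"    # illustrative stand-in for the real splitter: cut at the midpoint\n",
"    return (start + end) // 2\n",
"\n",
"def build_depth_first(n_samples, max_depth=3, min_samples_split=2):\n",
"    nodes = []                          # flat node store, like Tree\n",
"    stack = [(0, n_samples, 0, -1)]     # (start, end, depth, parent)\n",
"    while stack:                        # 3. stop when the stack is empty\n",
"        start, end, depth, parent = stack.pop()         # 1. pop a node\n",
"        is_leaf = depth >= max_depth or end - start < min_samples_split\n",
"        node_id = len(nodes)\n",
"        nodes.append({'parent': parent, 'start': start, 'end': end,\n",
"                      'is_leaf': is_leaf})\n",
"        if not is_leaf:                 # 2. push right first, then left\n",
"            pos = find_best_split(start, end)\n",
"            stack.append((pos, end, depth + 1, node_id))    # right child\n",
"            stack.append((start, pos, depth + 1, node_id))  # left child\n",
"    return nodes\n",
"\n",
"nodes = build_depth_first(8, max_depth=2)\n",
"# nodes[1] is the root's left child: the left subtree is expanded first\n",
"```\n",
"\n",
"Pushing the right child before the left one is what makes the traversal pre-order: the left child sits on top of the stack and is expanded next."
]
},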
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.2 BestFirstTreeBuilder\n",
"\n",
"BestFirstTreeBuilder always splits the most mixed node (the one with the largest impurity) first, using a max-heap as its data structure. Its flow is otherwise much the same as DepthFirstTreeBuilder's, so we omit the details."
]
},
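{
"cell_type": "markdown",
"metadata": {},
"source": [
"The best-first idea can be sketched with Python's `heapq` (a hypothetical toy, not sklearn's implementation): `heapq` is a min-heap, so impurities are pushed negated, and each split here simply produces two children with scaled-down impurity.\n",
"\n",
"```Python\n",
"import heapq\n",
"import itertools\n",
"\n",
"def build_best_first(root_impurity, max_leaf_nodes):\n",
"    ids = itertools.count()               # node ids, also tie-breakers\n",
"    heap = [(-root_impurity, next(ids))]  # min-heap of (-impurity, id)\n",
"    n_leaves = 1\n",
"    split_order = []\n",
"    while heap and n_leaves < max_leaf_nodes:\n",
"        neg_imp, node_id = heapq.heappop(heap)  # most impure node first\n",
"        split_order.append(node_id)\n",
"        n_leaves += 1                     # one leaf becomes two: net +1\n",
"        for child_imp in (-neg_imp * 0.6, -neg_imp * 0.3):  # toy split\n",
"            heapq.heappush(heap, (-child_imp, next(ids)))\n",
"    return split_order\n",
"\n",
"build_best_first(1.0, max_leaf_nodes=4)   # expands nodes 0, 1, 3\n",
"```\n",
"\n",
"Unlike the depth-first builder, the expansion order depends only on impurity, which is what makes a leaf budget meaningful: the most impure regions are refined first."
]
},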
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"### 2. Implementation Class\n",
"#### 2.0 Tree\n",
"sklearn implements the binary tree as flat arrays. The method I find most interesting is `compute_feature_importances`, which computes the feature importances.\n",
"\n",
"The idea is simple: walk over the internal (decision) nodes and accumulate each feature's contribution to the decrease in impurity. The code is short and easy to follow.\n",
"\n",
"```Python\n",
"cpdef compute_feature_importances(self, normalize=True):\n",
"    \"\"\"Computes the importance of each feature (aka variable).\"\"\"\n",
"    # ... 3 lines: cdef Node* left ...\n",
"    cdef Node* node = nodes\n",
"    cdef Node* end_node = node + self.node_count\n",
"    # ... 3 lines: cdef double normalizer = 0. ...\n",
"    cdef np.ndarray[np.float64_t, ndim=1] importances\n",
"    # ... 2 lines: importances = np.zeros((self.n_features,)) ...\n",
"\n",
"    with nogil:\n",
"        while node != end_node:\n",
"            if node.left_child != _TREE_LEAF:\n",
"                # ... and node.right_child != _TREE_LEAF:\n",
"                left = &nodes[node.left_child]\n",
"                right = &nodes[node.right_child]\n",
"\n",
"                importance_data[node.feature] += (\n",
"                    node.weighted_n_node_samples * node.impurity -\n",
"                    left.weighted_n_node_samples * left.impurity -\n",
"                    right.weighted_n_node_samples * right.impurity)\n",
"            node += 1\n",
"\n",
"    importances /= nodes[0].weighted_n_node_samples\n",
"\n",
"    if normalize:\n",
"        normalizer = np.sum(importances)\n",
"        # ... 3 lines: if normalizer > 0.0: ...\n",
"        importances /= normalizer\n",
"\n",
"    return importances\n",
"```"
]
},
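{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same accumulation can be replayed in plain NumPy over a toy node table (a hypothetical layout that keeps only the fields the loop reads; `left = -1` marks a leaf, playing the role of `_TREE_LEAF`):\n",
"\n",
"```Python\n",
"import numpy as np\n",
"\n",
"# Hypothetical 3-node tree: the root splits on feature 1,\n",
"# and both of its children are leaves.\n",
"toy_nodes = [\n",
"    {'feature': 1, 'impurity': 0.5, 'weight': 10.0, 'left': 1, 'right': 2},\n",
"    {'feature': -1, 'impurity': 0.2, 'weight': 6.0, 'left': -1, 'right': -1},\n",
"    {'feature': -1, 'impurity': 0.0, 'weight': 4.0, 'left': -1, 'right': -1},\n",
"]\n",
"\n",
"def feature_importances(nodes, n_features, normalize=True):\n",
"    imp = np.zeros(n_features)\n",
"    for node in nodes:\n",
"        if node['left'] != -1:            # internal nodes only\n",
"            left, right = nodes[node['left']], nodes[node['right']]\n",
"            imp[node['feature']] += (\n",
"                node['weight'] * node['impurity']\n",
"                - left['weight'] * left['impurity']\n",
"                - right['weight'] * right['impurity'])\n",
"    imp /= nodes[0]['weight']             # scale by the root's weight\n",
"    if normalize and imp.sum() > 0.0:\n",
"        imp /= imp.sum()\n",
"    return imp\n",
"\n",
"feature_importances(toy_nodes, n_features=2)  # all weight on feature 1\n",
"```\n",
"\n",
"Each internal node contributes its weighted impurity minus that of its two children, i.e. the weighted impurity decrease its split achieved."
]
},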
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Additionally, in the `tree.py` module the decision-tree classifier can also report class probabilities; the probability is simply the predicted class's share of the samples in its leaf.\n",
"\n",
"The full computation path is:\n",
"\n",
"1. `_tree.pyx:Tree.predict` locates the leaf node via `_tree.pyx:Tree.apply`, then uses `_tree.pyx:Tree._get_value_ndarray` to obtain the per-class sample counts at that leaf.\n",
"\n",
"2. `tree.py:DecisionTreeClassifier.predict_proba` computes the proportions."
]
},
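{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of that path (the values are hypothetical; in sklearn the counts come from `Tree.apply` plus the leaf's slice of the value array):\n",
"\n",
"```Python\n",
"import numpy as np\n",
"\n",
"# Hypothetical per-class sample counts stored at the leaf a sample\n",
"# lands in (step 1 above).\n",
"leaf_counts = np.array([3.0, 1.0, 0.0])   # counts for classes 0, 1, 2\n",
"\n",
"def predict_proba_from_leaf(counts):\n",
"    # step 2: each class probability is its share of the leaf's samples\n",
"    total = counts.sum()\n",
"    return counts / total if total > 0 else counts\n",
"\n",
"predict_proba_from_leaf(leaf_counts)      # [0.75, 0.25, 0.0]\n",
"```"
]
},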
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Summary\n",
"\n",
"This note covered the two strategies for building a decision tree, along with how feature importances and prediction probabilities are computed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}