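{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [
"### Introduction\n",
"\n",
"Let's start with an example. Suppose a bank uses the following rule set to decide whether to grant a loan:\n",
"\n",
"```python\n",
"if age == 'young':\n",
"    if has_job == 'yes':\n",
"        if credit == 'excellent':\n",
"            decision = 'approve'\n",
"        else:\n",
"            decision = 'reject'\n",
"    else:\n",
"        if owns_house == 'yes':\n",
"            if credit == 'fair':\n",
"                decision = 'reject'\n",
"            else:\n",
"                decision = 'approve'\n",
"        else:\n",
"            if credit == 'excellent' or credit == 'good':\n",
"                decision = 'approve'\n",
"            else:\n",
"                if has_job == 'yes':\n",
"                    decision = 'approve'\n",
"                else:\n",
"                    decision = 'reject'\n",
"elif age == 'middle_aged':\n",
"    if owns_house == 'yes':\n",
"        decision = 'approve'\n",
"    else:\n",
"        if credit == 'excellent' or credit == 'good':\n",
"            decision = 'approve'\n",
"        else:\n",
"            if has_job == 'yes':\n",
"                decision = 'approve'\n",
"            else:\n",
"                decision = 'reject'\n",
"elif age == 'old':\n",
"    if owns_house == 'yes':\n",
"        if credit == 'excellent' or credit == 'good':\n",
"            decision = 'approve'\n",
"        else:\n",
"            decision = 'reject'\n",
"    else:\n",
"        if credit == 'excellent' or credit == 'good':\n",
"            if has_job == 'yes':\n",
"                decision = 'approve'\n",
"            else:\n",
"                decision = 'reject'\n",
"        else:\n",
"            decision = 'reject'\n",
"if owns_house == 'yes':\n",
"    decision = 'approve'\n",
"else:\n",
"    if has_job == 'yes':\n",
"        decision = 'approve'\n",
"    else:\n",
"        decision = 'reject'\n",
"```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"A sharp-eyed reader will immediately notice that this code has problems. For example, users with `credit == 'excellent'` are approved in almost every branch, so why bury that check deep inside the nesting? Many of the rules are also redundant, so why not refactor? In practice, you may not dare to touch it: the project manager might add new rules any day, so people would rather let the code grow more redundant and more complex than modify existing rules, because changing even a couple of them can have unexpected consequences. To summarize, deeply nested `if else` rules tend to suffer from the following pain points:\n",
"\n",
"(1) The rules may be incomplete: some cases match no rule;\n",
"\n",
"(2) The rules may be redundant: several `if else` branches actually test the same condition;\n",
"\n",
"(3) In severe cases the rules contradict each other: the same conditions lead to both **approve** and **reject**;\n",
"\n",
"(4) The priority of the tests is muddled: for example, `credit` could be tested first, because an `'excellent'` rating is meant to be enough to approve the loan without checking the other conditions.\n",
"\n",
"A decision tree addresses these pain points. It guarantees that the rules are **mutually exclusive and complete**: every possible case matches exactly one rule, which resolves pain points 1 to 3. The order in which features are tested is also chosen sensibly, which addresses pain point 4. The decision tree learning algorithm is introduced below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"### Decision tree learning\n",
"A decision tree algorithm learns an `if else` rule set automatically from labeled data, as shown in the figure below ([image source>>>](https://www.cnblogs.com/jin-liang/p/9609144.html)). On the left is a collection of examples recording whether a game was played, with 4 features (outlook, temperature, Humidity, Wind) and a label y indicating whether to play. Decision tree learning turns these examples into the tree on the right. The figure also shows the **structure of a decision tree**: it consists of nodes and directed edges, and the nodes come in two kinds, leaf nodes and non-leaf (internal) nodes. An internal node tests one feature, each directed edge leaving it corresponds to a condition on that feature, and a leaf node gives the predicted value for the instance (a class for classification, a number for regression).\n",
"\n",
"\n",
"![avatar](./source/09_决策树学习.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Decision tree learning has two main stages: **tree generation** and **tree pruning**. The most important part of the generation stage is **feature selection**. These concepts are introduced in the sections that follow."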
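, "\n",
"\n",
"Before moving on, the tree structure described above can be made concrete with a minimal sketch that stores a tree as nested dicts and classifies one sample by walking a single root-to-leaf path. The dict layout, the `predict` helper, and the particular splits (the classic play-tennis tree: outlook first, then humidity or wind) are assumptions chosen for illustration; they are not read from the figure and are not the implementation used later in this notebook.\n",
"\n",
"```python\n",
"# Hypothetical tree for the play-tennis data (classic textbook splits, assumed here).\n",
"tree = {\n",
"    'feature': 'outlook',\n",
"    'children': {\n",
"        'sunny': {'feature': 'humidity',\n",
"                  'children': {'high': {'label': 'no'}, 'normal': {'label': 'yes'}}},\n",
"        'overcast': {'label': 'yes'},\n",
"        'rain': {'feature': 'wind',\n",
"                 'children': {'strong': {'label': 'no'}, 'weak': {'label': 'yes'}}},\n",
"    },\n",
"}\n",
"\n",
"\n",
"def predict(node, sample):\n",
"    # An internal node tests one feature; the sample's value picks the edge to follow.\n",
"    while 'label' not in node:\n",
"        node = node['children'][sample[node['feature']]]\n",
"    # A leaf node stores the predicted class.\n",
"    return node['label']\n",
"\n",
"\n",
"print(predict(tree, {'outlook': 'sunny', 'humidity': 'normal', 'wind': 'weak'}))  # yes\n",
"```\n",
"\n",
"Because every internal node sends a sample down exactly one edge, each sample reaches exactly one leaf, which is what makes the rules mutually exclusive and complete (as long as every feature value has an edge)."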
] }, { "cell_type": "markdown", "metadata": {}, "source": [
"#### 1. Feature selection\n",
"\n",
"Feature selection picks the features that are useful for classification. ID3 and C4.5 use information gain and information gain ratio, respectively, as the selection criterion. Both are introduced and implemented below.\n",
"\n",
"##### Information gain\n",
"First, the mutual information between two random variables is defined as:\n",
"\n",
"$$\n",
"MI(X,Y)=H(Y)-H(Y|X)\n",
"$$\n",
"\n",
"Here $H(X)$ is the entropy of $X$, which was introduced in the section on the maximum entropy model:\n",
"\n",
"$$\n",
"H(X)=-\\sum_{i=1}^{n}p_i\\log p_i,\\quad \\text{where } p_i=P(X=x_i)\n",
"$$\n",
"\n",
"The conditional entropy $H(Y|X)$ is the uncertainty of the random variable $Y$ given the random variable $X$:\n",
"\n",
"$$\n",
"H(Y|X)=\\sum_{i=1}^{n}p_iH(Y|X=x_i),\\quad \\text{where } p_i=P(X=x_i)\n",
"$$\n",
"\n",
"Information gain is simply this mutual information with $Y$ taken to be the class label and $X$ a particular feature. It measures how much the entropy of $Y$ drops when the data is split on feature $X$. The larger the drop, the more concentrated the distribution of $Y$ is within each subset after the split, and the more useful $X$ is for predicting the label $Y$. A Python implementation follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"\"\"\"\n",
"Define the entropy function (packaged in ml_models.utils)\n",
"\"\"\"\n",
"import numpy as np\n",
"from collections import Counter\n",
"import math\n",
"def entropy(x,sample_weight=None):\n",
"    x=np.asarray(x)\n",
"    #number of elements in x\n",
"    x_num=len(x)\n",
"    #if sample_weight is None, give every sample the same weight\n",
"    if sample_weight is None:\n",
"        sample_weight=np.asarray([1.0]*x_num)\n",
"    x_counter={}\n",
"    weight_counter={}\n",
"    # count how often each value of x occurs and collect the corresponding sample weights\n",
"    for index in range(0,x_num):\n",
"        x_value=x[index]\n",
"        if x_counter.get(x_value) is None:\n",
"            x_counter[x_value]=0\n",
"            weight_counter[x_value]=[]\n",
"        x_counter[x_value]+=1\n",
"        weight_counter[x_value].append(sample_weight[index])\n",
"\n",
"    #compute the entropy (natural logarithm)\n",
"    ent=.0\n",
"    for key,value in x_counter.items():\n",
"        p_i=1.0*value*np.mean(weight_counter.get(key))/x_num\n",
"        ent+=-p_i*math.log(p_i)\n",
"    return ent" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6931471805599453" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
"#test\n",
"entropy([1,2])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"def cond_entropy(x, y,sample_weight=None):\n",
"    \"\"\"\n",
"    Compute the conditional entropy H(y|x)\n",
"    \"\"\"\n",
"    x=np.asarray(x)\n",
"    y=np.asarray(y)\n",
"    # number of elements in x\n",
"    x_num = len(x)\n",
"    #if sample_weight is None, give every sample the same weight\n",
"    if sample_weight is None:\n",
"        sample_weight=np.asarray([1.0]*x_num)\n",
"    # accumulate the entropy of y within each subset x==x_value\n",
"    ent = .0\n",
"    for x_value in set(x):\n",
"        x_index=np.where(x==x_value)\n",
"        new_x=x[x_index]\n",
"        new_y=y[x_index]\n",
"        new_sample_weight=sample_weight[x_index]\n",
"        p_i=1.0*len(new_x)/x_num\n",
"        ent += p_i * entropy(new_y,new_sample_weight)\n",
"    return ent" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
"#test\n",
"cond_entropy([1,2],[1,2])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
"def muti_info(x, y,sample_weight=None):\n",
"    \"\"\"\n",
"    Mutual information / information gain: H(y)-H(y|x)\n",
"    \"\"\"\n",
"    x_num=len(x)\n",
"    if sample_weight is None:\n",
"        sample_weight=np.asarray([1.0]*x_num)\n",
"    return entropy(y,sample_weight) - cond_entropy(x, y,sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Next, run a test to see how the number of distinct values a feature takes affects the information gain" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"import random\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"#repeat the test epochs times\n",
"epochs=100\n",
"#number of distinct values of x: from 2 up to class_num_x (exclusive)\n",
"class_num_x=100\n",
"#number of classes of the label y\n",
"class_num_y=2\n",
"#number of samples\n",
"num_samples=500\n",
"info_gains=[]\n",
"for _ in range(0,epochs):\n",
"    info_gain=[]\n",
"    for class_x in range(2,class_num_x):\n",
"        x=[]\n",
"        y=[]\n",
"        for _ in range(0,num_samples):\n",
"            x.append(random.randint(1,class_x))\n",
"            y.append(random.randint(1,class_num_y))\n",
"        info_gain.append(muti_info(x,y))\n",
"    info_gains.append(info_gain)\n",
"plt.plot(np.asarray(info_gains).mean(axis=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"An interesting pattern emerges: the more distinct values a feature takes, the larger its information gain and the more likely it is to be selected, even though $x$ and $y$ are generated independently here. This is easy to understand from an extreme case: if the feature $x$ takes a different value for every instance, then $H(Y|X)$ is 0 and $MI(X,Y)=H(Y)-H(Y|X)$ reaches its maximum (since $H(Y)$ does not depend on $X$). This is a weakness of the ID3 algorithm. To correct for it, C4.5 selects features by information gain ratio." ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"##### Information gain ratio\n",
"The information gain ratio is just the information gain divided by the entropy of $X$:\n",
"\n",
"$$\n",
"\\frac{MI(X,Y)}{H(X)}\n",
"$$" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"def info_gain_rate(x, y,sample_weight=None):\n",
"    \"\"\"\n",
"    Information gain ratio\n",
"    \"\"\"\n",
"    x_num=len(x)\n",
"    if sample_weight is None:\n",
"        sample_weight=np.asarray([1.0]*x_num)\n",
"    # the 1e-12 term avoids division by zero when x takes a single value\n",
"    return 1.0 * muti_info(x, y,sample_weight) / (1e-12 + entropy(x,sample_weight))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Next, run the same test again:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"#repeat the test epochs times\n",
"epochs=100\n",
"#number of distinct values of x: from 2 up to class_num_x (exclusive)\n",
"class_num_x=100\n",
"#number of classes of the label y\n",
"class_num_y=2\n",
"#number of samples\n",
"num_samples=500\n",
"info_gain_rates=[]\n",
"for _ in range(0,epochs):\n",
"    info_gain_rate_=[]\n",
"    for class_x in range(2,class_num_x):\n",
"        x=[]\n",
"        y=[]\n",
"        for _ in range(0,num_samples):\n",
"            x.append(random.randint(1,class_x))\n",
"            y.append(random.randint(1,class_num_y))\n",
"        info_gain_rate_.append(info_gain_rate(x,y))\n",
"    info_gain_rates.append(info_gain_rate_)\n",
"plt.plot(np.asarray(info_gain_rates).mean(axis=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Although the overall trend is still upward, it is much milder than with information gain. Plotting the two together makes the difference easy to see:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"#red: information gain, yellow: information gain ratio\n",
"plt.plot(np.asarray(info_gains).mean(axis=0),'r')\n",
"plt.plot(np.asarray(info_gain_rates).mean(axis=0),'y')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"#### 2. Tree generation\n",
"Generating a decision tree is a recursive application of feature selection. Starting at the root node, the feature with the best information gain / information gain ratio is chosen as the node's feature, and one child node is created for each distinct value of that feature. The same procedure is then applied to each child node, until the information gain / information gain ratio of every feature is very small or no features remain to choose from. The result is a decision tree. The implementation follows:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
"import os\n",
"os.chdir('../')\n",
"from ml_models import utils\n",
"from ml_models.wrapper_models import DataBinWrapper\n",
"\"\"\"\n",
"Implementation of the ID3 and C4.5 decision tree classifiers; it lives in the ml_models.tree module\n",
"\"\"\"\n",
"class DecisionTreeClassifier(object):\n",
"    class Node(object):\n",
"        \"\"\"\n",
"        Tree node: stores the node's information and links to its child nodes\n",
"        \"\"\"\n",
"\n",
"        def __init__(self, feature_index: int = None, target_distribute: dict = None, weight_distribute: dict = None,\n",
"                     children_nodes: dict = None, num_sample: int = None):\n",
"            \"\"\"\n",
"            :param feature_index: index of the feature this node splits on\n",
"            :param target_distribute: distribution of the target classes\n",
"            :param weight_distribute: distribution of the sample weights\n",
"            :param children_nodes: child nodes\n",
"            :param num_sample: number of samples\n",
"            \"\"\"\n",
"            self.feature_index = feature_index\n",
"            self.target_distribute = target_distribute\n",
"            self.weight_distribute = weight_distribute\n",
"            self.children_nodes = children_nodes\n",
"            self.num_sample = num_sample\n",
"\n",
"    def __init__(self, criterion='c4.5', max_depth=None, min_samples_split=2, min_samples_leaf=1,\n",
"                 min_impurity_decrease=0, max_bins=10):\n",
"        \"\"\"\n",
"        :param criterion: split criterion, 'id3' or 'c4.5'; defaults to 'c4.5'\n",
"        :param max_depth: maximum depth of the tree\n",
"        :param min_samples_split: minimum number of samples a node must contain to be split; defaults to 2\n",
"        :param min_samples_leaf: minimum number of samples in a leaf node; defaults to 1\n",
"        :param min_impurity_decrease: a node is split only if the split reduces the impurity (as measured by criterion) by more than this value; defaults to 0\n",
"        :param max_bins: maximum number of bins, passed to DataBinWrapper to discretize the features\n",
"        \"\"\"\n",
"        self.criterion = criterion\n",
"        if criterion == 'c4.5':\n",
"            self.criterion_func = utils.info_gain_rate\n",
"        else:\n",
"            self.criterion_func = utils.muti_info\n",
"        self.max_depth = max_depth\n",
"        self.min_samples_split = min_samples_split\n",
"        self.min_samples_leaf = min_samples_leaf\n",
"        self.min_impurity_decrease = min_impurity_decrease\n",
"\n",
"        self.root_node: self.Node = None\n",
"        self.sample_weight = None\n",
"        self.dbw = DataBinWrapper(max_bins=max_bins)\n",
"\n",
"    def _build_tree(self, current_depth, current_node: Node, x, y, sample_weight):\n",
"        \"\"\"\n",
"        Recursively select features and build the tree\n",
"        :param x:\n",
"        :param y:\n",
"        :param sample_weight:\n",
"        :return:\n",
"        \"\"\"\n",
"        rows, cols = x.shape\n",
"        # compute the distribution of y and of the corresponding sample weights\n",
"        target_distribute = {}\n",
"        weight_distribute = {}\n",
"        for index, tmp_value in enumerate(y):\n",
"            if tmp_value not in target_distribute:\n",
"                target_distribute[tmp_value] = 0.0\n",
"                weight_distribute[tmp_value] = []\n",
"            target_distribute[tmp_value] += 1.0\n",
"            weight_distribute[tmp_value].append(sample_weight[index])\n",
"        for key, value in target_distribute.items():\n",
"            target_distribute[key] = value / rows\n",
"            weight_distribute[key] = np.mean(weight_distribute[key])\n",
"        current_node.target_distribute = target_distribute\n",
"        current_node.weight_distribute = weight_distribute\n",
"        current_node.num_sample = rows\n",
"        # check the stopping conditions\n",
"\n",
"        if len(target_distribute) <= 1:\n",
"            return\n",
"\n",
"        if rows < self.min_samples_split:\n",
"            return\n",
"\n",
"        if self.max_depth is not None and current_depth > self.max_depth:\n",
"            return\n",
"\n",
"        # find the best feature\n",
"        best_index = None\n",
"        best_criterion_value = 0\n",
"        for index in range(0, cols):\n",
"            criterion_value = self.criterion_func(x[:, index], y)\n",
"            if criterion_value > best_criterion_value:\n",
"                best_criterion_value = criterion_value\n",
"                best_index = index\n",
"\n",
"        # stop if the criterion value does not improve enough\n",
"        if best_index is None:\n",
"            return\n",
"        if best_criterion_value <= self.min_impurity_decrease:\n",
"            return\n",
"        # split\n",
"        current_node.feature_index = best_index\n",
"        children_nodes = {}\n",
"        current_node.children_nodes = children_nodes\n",
"        selected_x = x[:, best_index]\n",
"        for item in set(selected_x):\n",
"            selected_index = np.where(selected_x == item)\n",
"            # if the split leaves too few samples to form a leaf node, skip this branch\n",
"            if len(selected_index[0]) < self.min_samples_leaf:\n",
"                continue\n",
"            child_node = self.Node()\n",
"            children_nodes[item] = child_node\n",
"            self._build_tree(current_depth + 1, child_node, x[selected_index], y[selected_index],\n",
"                             sample_weight[selected_index])\n",
"\n",
"    def fit(self, x, y, sample_weight=None):\n",
"        # default to equal sample weights\n",
"        n_sample = x.shape[0]\n",
"        if sample_weight is None:\n",
"            self.sample_weight = np.asarray([1.0] * n_sample)\n",
"        else:\n",
"            self.sample_weight = sample_weight\n",
"        # check the length of sample_weight\n",
"        if len(self.sample_weight) != n_sample:\n",
"            raise Exception('sample_weight size error:', len(self.sample_weight))\n",
"\n",
"        # create an empty root node\n",
"        self.root_node = self.Node()\n",
"\n",
"        # bin x\n",
"        self.dbw.fit(x)\n",
"\n",
"        # build the tree recursively\n",
"        self._build_tree(1, self.root_node, self.dbw.transform(x), y, self.sample_weight)\n",
"\n",
"    # look up the result stored at the leaf node that x reaches\n",
"    def _search_node(self, current_node: Node, x, class_num):\n",
"        if current_node.feature_index is None or current_node.children_nodes is None or len(\n",
"                current_node.children_nodes) == 0 or current_node.children_nodes.get(\n",
"            x[current_node.feature_index]) is None:\n",
"            result = []\n",
"            total_value = 0.0\n",
"            for index in range(0, class_num):\n",
"                value = current_node.target_distribute.get(index, 0) * current_node.weight_distribute.get(index, 1.0)\n",
"                result.append(value)\n",
"                total_value += value\n",
"            # normalize\n",
"            for index in range(0, class_num):\n",
"                result[index] = result[index] / total_value\n",
"            return result\n",
"        else:\n",
"            return self._search_node(current_node.children_nodes.get(x[current_node.feature_index]), x, class_num)\n",
"\n",
"    def predict_proba(self, x):\n",
"        # compute the predicted probability distribution\n",
"        x = self.dbw.transform(x)\n",
"        rows = x.shape[0]\n",
"        results = []\n",
"        class_num = len(self.root_node.target_distribute)\n",
"        for row in range(0, rows):\n",
"            results.append(self._search_node(self.root_node, x[row], class_num))\n",
"        return np.asarray(results)\n",
"\n",
"    def predict(self, x):\n",
"        return np.argmax(self.predict_proba(x), axis=1)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
"#generate synthetic data\n",
"from sklearn.datasets import make_classification\n",
"data, target = make_classification(n_samples=100, n_features=2, n_classes=2, n_informative=1, n_redundant=0,\n",
"                                   n_repeated=0, n_clusters_per_class=1, class_sep=.5,random_state=21)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"#train and inspect the result\n",
"tree = DecisionTreeClassifier(max_bins=15)\n",
"tree.fit(data, target)\n",
"utils.plot_decision_function(data, target, tree)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"As the plot shows, if no constraints are placed on the decision tree, it tries to create very fine-grained rules so that every training sample is classified correctly, which clearly leads to overfitting. The next step is therefore to prune the tree to avoid overfitting." ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"#### 3. Tree pruning\n",
"As the name suggests, pruning cuts away unnecessary leaf nodes. How do we decide which leaf nodes to remove and which to keep? This can be quantified with a loss function: if cutting off the leaf nodes under a node does not increase the loss, they are pruned; otherwise they are kept. A simple loss function can be defined as follows:\n",
"\n",
"$$\n",
"C_\\alpha(T)=\\sum_{t=1}^{\\mid T\\mid}N_tH_t(T)+\\alpha\\mid T\\mid\n",
"$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Here $\\mid T \\mid$ is the number of leaf nodes of tree $T$, and $t$ is a leaf node of $T$ containing $N_t$ samples, of which $N_{tk}$ belong to class $k$, $k=1,2,3,...,K$. $H_t(T)$ is the empirical entropy at leaf node $t$, and $\\alpha\\geq 0$ is a hyperparameter, where:\n",
"\n",
"$$\n",
"H_t(T)=-\\sum_k\\frac{N_{tk}}{N_t}\\log\\frac{N_{tk}}{N_t}\n",
"$$\n",
"\n",
"The loss has two parts. The first, $\\sum_{t=1}^{\\mid T\\mid}N_tH_t(T)$, is the empirical loss; the second, $\\mid T \\mid$, is the structural loss; and $\\alpha$ is the coefficient that balances them. The larger $\\alpha$ is, the simpler the resulting model and the less prone it is to overfitting. The pruning code follows:\n",
"\n",
"```python\n",
"    def _prune_node(self, current_node: Node, alpha):\n",
"        # if there are child nodes, prune within the children first\n",
"        if current_node.children_nodes is not None and len(current_node.children_nodes) != 0:\n",
"            for child_node in current_node.children_nodes.values():\n",
"                self._prune_node(child_node, alpha)\n",
"\n",
"        # then try to prune at the current node\n",
"        if current_node.children_nodes is not None and len(current_node.children_nodes) != 0:\n",
"            # avoid pruning across levels\n",
"            for child_node in current_node.children_nodes.values():\n",
"                # only prune when every child is a leaf node\n",
"                if child_node.children_nodes is not None and len(child_node.children_nodes) > 0:\n",
"                    return\n",
"            # loss before pruning\n",
"            pre_prune_value = alpha * len(current_node.children_nodes)\n",
"            for child_node in current_node.children_nodes.values():\n",
"                for key, value in child_node.target_distribute.items():\n",
"                    pre_prune_value += -1 * child_node.num_sample * value * np.log(\n",
"                        value) * child_node.weight_distribute.get(key, 1.0)\n",
"            # loss after pruning\n",
"            after_prune_value = alpha\n",
"            for key, value in current_node.target_distribute.items():\n",
"                after_prune_value += -1 * current_node.num_sample * value * np.log(\n",
"                    value) * current_node.weight_distribute.get(key, 1.0)\n",
"\n",
"            if after_prune_value <= pre_prune_value:\n",
"                # prune: the current node becomes a leaf\n",
"                current_node.children_nodes = None\n",
"                current_node.feature_index = None\n",
"\n",
"    def prune(self, alpha=0.01):\n",
"        \"\"\"\n",
"        Prune the decision tree: C(T)+alpha*|T|\n",
"        :param alpha:\n",
"        :return:\n",
"        \"\"\"\n",
"        # prune recursively\n",
"        self._prune_node(self.root_node, alpha)\n",
"```" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from ml_models.tree import DecisionTreeClassifier\n",
"#train and inspect the result\n",
"tree = DecisionTreeClassifier(max_bins=15)\n",
"tree.fit(data, target)\n",
"tree.prune(alpha=1.5)\n",
"utils.plot_decision_function(data, target, tree)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"By tuning $\\alpha$ we can obtain a reasonably satisfying pruning result. This style of pruning is usually called **post-pruning**: pruning starts from a fully grown tree. Its counterpart is **pre-pruning**, where pruning happens during training; this usually requires a separate **validation set** and is not implemented here. Another common practice is to control model complexity through a few parameters: `max_depth` limits the maximum depth of the tree, `min_samples_leaf` sets the minimum number of samples in a leaf node, `min_impurity_decrease` sets the minimum impurity decrease required for a split, and `min_samples_split` sets the minimum number of samples a node needs before it can be split. Tuning these parameters has a pruning-like effect. For example, setting the minimum number of samples per leaf, as below, gives much the same result as the pruning above:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"tree = DecisionTreeClassifier(max_bins=15,min_samples_leaf=3)\n",
"tree.fit(data, target)\n",
"utils.plot_decision_function(data, target, tree)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"### Another view of decision trees: conditional probability distributions\n",
"A decision tree can also be viewed as a conditional probability distribution over classes given the features:\n",
"\n",
"(1) During training, the decision tree partitions the feature space into disjoint regions of various sizes, each associated with a probability distribution over classes;\n",
"\n",
"(2) During prediction, a sample that falls into a region is assigned the class with the highest probability in that region." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }