8.4 决策树剪枝过程#

8.4.1 剪枝思想#

在8.3节内容中我们介绍过,使用ID3算法进行构建决策树时容易产生过拟合现象,因此需要使用一种方法来缓解这一现象。通常,决策树过拟合的表现形式为这棵树有很多叶子节点。想象一下,如果这棵树为每个样本点都生成一个叶节点,也就代表着这棵树能够拟合所有的样本点,因为决策树的每个叶节点都表示一个分类类别。同时,出现过拟合的原因在于模型在学习时过多地考虑如何提高对训练数据的正确分类,从而构建出过于复杂的决策树,因此,解决这一问题的办法就是考虑减少决策树的复杂度,即对已经生成的决策树进行简化,也就是剪枝(Pruning)。

8.4.2 剪枝步骤#

决策树的剪枝往往通过最小化决策树整体的损失函数或者代价函数实现。设树$T$的叶节点个数为$|T|$,$t$是树$T$的一个叶节点,该叶节点有$N_t$个样本点,其中类别$k$的样本点有$N_{tk}$个,其中$k=1,2,...,K$。同时,$H_t(T)$为叶节点$t$上的经验熵,$\alpha \geq 0$为参数,则决策树的损失函数可以定义为

$$ {{C}_{\alpha }}(T)=\sum\limits_{t=1}^{|T|}{{{N}_{t}}}{{H}_{t}}(T)+\alpha |T|\tag{8-27} $$

其中经验熵为

$$ {{H}_{t}}(T)=-\sum\limits_{k}^{K}{\frac{{{N}_{tk}}}{{{N}_{t}}}}\log \frac{{{N}_{tk}}}{{{N}_{t}}}\tag{8-28} $$

进一步令

$$ C(T)=\sum\limits_{t=1}^{|T|}{{{N}_{t}}}{{H}_{t}}(T)=-\sum\limits_{t=1}^{|T|}{\sum\limits_{k=1}^{K}{{{N}_{tk}}}}\log \frac{{{N}_{tk}}}{{{N}_{t}}}\tag{8-29} $$

此时损失函数可以写为

$$ {{C}_{\alpha }}(T)=C(T)+\alpha |T|\tag{8-30} $$

其中$C(T)$表示模型对训练数据的分类误差,即模型与训练集的拟合程度,本质上就是所有叶子结点总的信息熵;$|T|$表示模型复杂度,参数$\alpha\geq0$用于控制两者之间的平衡。此时可以发现,较大的$\alpha$促使选择较简单的模型(树),较小的$\alpha$促使选择较复杂的模型(树),而$\alpha=0$则意味着只考虑模型与训练集的拟合程度,而不考虑模型的复杂度,因此,这里$\alpha$的作用就类似于正则化中惩罚系数。

具体地,决策树的剪枝步骤如下

输入: 生成算法产生的整棵树$T$,参数$\alpha$。

输出: 修剪后的子树$T_{\alpha}$

(1) 计算每个叶节点的经验(信息)熵。

(2) 递归地从树的叶节点往上回溯,设一组叶节点回溯到其父节点之前与之后的整体树分别为$T_B$和$T_A$,其对应的损失函数值分别是$C_{\alpha}(T_B)$和$C_{\alpha}(T_A)$,如果$C_{\alpha}(T_A)\leq C_{\alpha}(T_B)$,则进行剪枝,即将父节点变为新的叶节点。

(3) 返回步骤(2),直到不能继续为止,得到损失函数最小的子树$T_{\alpha}$。

当然,如果仅看这些步骤依旧会很模糊,下面再来通过一个实际计算示例进行说明。

8.4.3 剪枝示例#

如图8-9所示,在考虑是否要减掉“学历等级”这个节点时,首先需要计算的就是剪枝前的损失函数数值$C_\alpha(T_B)$。由于剪枝时,每次只考虑一个节点,所以在计算剪枝前和剪枝后的损失函数值时,仅考虑该节点即可。因为其他叶节点的经验熵对于剪枝前和剪枝后都没有变化。

图 8-9 决策树剪枝

根据表8-4可知,“学历等级”这个节点对应的训练数据如表8-5所示。

表 8-5 学历等级样本分布表

根据式(8-29)有

$$ C({{T}_{B}})=-\sum\limits_{t=1}^{2}{\sum\limits_{k=1}^{2}{{{N}_{tk}}}}\log \frac{{{N}_{tk}}}{{{N}_{t}}}=-\left[ \left(2{{\log }_{2}}\frac{2}{2}+0\right)+\left(1{{\log }_{2}}\frac{1}{2}+1{{\log }_{2}}\frac{1}{2}\right) \right]=2\tag{8-31} $$

进一步,根据式(8-30)有

$$ {{C}_{\alpha }}({{T}_{B}})=C({{T}_{B}})+\alpha |{{T}_{B}}|=2+2\alpha\tag{8-32} $$

29