8.8 从零实现CART算法及剪枝示例#

在8.7节内容中我们已经详细介绍了CART分类树的构建原理和剪枝过程,在本节内容中将开始介绍如何从零开始编码实现一个简易版的CART分类树。同时,由于在实际场景中数据集的特征维度大多都是连续型的特征变量,因此需要对其进行离散化处理,后续也将以连续型特征为例进行编码实现。具体离散化方法见8.6.1节内容,这里不再赘述。

下面,将直接通过表8-6中的样本数据来详细介绍在CART分类树中如何离散化连续型特征并构造相应的决策树。

8.8.1 连续型特征生成示例#

在处理连续型特征变量时,第一步便是需要将各个维度的特征进行离散化,然后再根据离散化后的区间来对特征进行判断。具体地,表8-6中对应的3个特征在离散化后各特征的取值分割点分别为: $[0.5]$ 、 $[0.5]$ 、$ [0.5,1.5]$。

由表8-6中的数据可知,根据式(8-35)可得,此时的其基尼不纯度为

$$ \text{Gini}(D)=1-\sum\limits_{k=1}^{K}{{{\left( \frac{|{{C}_{k}}|}{|D|} \right)}^{2}}}=1-{{\left( \frac{5}{15} \right)}^{2}}-{{\left( \frac{10}{15} \right)}^{2}}\approx 0.444\tag{8-51} $$

并且对于特征${{A}_{1}}$来说,根据其取值是否满足条件$A_1\leq0.5$,可以将原始样本划分为$D_1$和$D_2$两个部分。由式(8-37)得

$$ \text{Gini}(D,{A}_{1}\leq0.5)=\frac{7}{15}\text{Gini}({{D}_{1}})+\frac{8}{15}\text{Gini}({{D}_{2}})=\frac{7}{15}\times \frac{24}{49}+\frac{8}{15}\times \frac{14}{64}\approx 0.345\tag{8-52} $$

同理,对于特征${{A}_{2}}$来说,根据其取值是否满足条件$A_2\leq0.5$,也可以将原始样本划分为$D_1$和$D_2$两个部分。此时有

$$ \text{Gini}(D,{A}_{2}\leq0.5)=\frac{8}{15}\text{Gini}({{D}_{1}})+\frac{7}{15}\text{Gini}({{D}_{2}})=\frac{8}{15}\times \frac{1}{2}+\frac{7}{15}\times \frac{12}{49}\approx 0.381\tag{8-53} $$

进一步,对于特征${{A}_{3}}$来说,根据其取值是否分别满足条件$A_3\leq0.5$和$A_3\leq1.5$,每一次也可将原始样本划分为$D_1$和$D_2$两个部分。此时有

$$ \begin{aligned} \text{Gini}(D,{A}_{3}\leq0.5)&=\frac{12}{15}\text{Gini}({{D}_{1}})+\frac{3}{15}\text{Gini}({{D}_{2}})=\frac{12}{15}\times \frac{4}{9}+\frac{3}{15}\times \frac{4}{9}\approx 0.444\\[2ex] \text{Gini}(D,{A}_{3}\leq1.5)&=\frac{7}{15}\text{Gini}({{D}_{1}})+\frac{8}{15}\text{Gini}({{D}_{2}})=\frac{7}{15}\times \frac{20}{49}+\frac{8}{15}\times \frac{30}{64}\approx 0.441 \end{aligned}\tag{8-54} $$

注意:此时每次划分时都是将样本集合划分为两个部分,即$A_i \leq a$和$A_i> a$。

由以上计算结果可知,使用${A}_{1}\leq0.5$时对样本集合进行划分所得到的基尼不纯度最小。故,根节点应该以${A}_{1}\leq 0.5$​是否成立来进行分割,如图8-21所示(在每个节点中,第1行表示当前节点的划分特征维度以及样本的索引,第2行分别表示判断区间和每个类别中的各个样本的数量)。

图 8-21. CART第一次划分

从图8-21可以看出,根据特征$A_1$是否存满足条件$A_1\leq0.5$可以将原始样本划分为$D_1$和$D_2$两个部分。经过这次划分后,原始的样本集合就被特征“有工作”分割成了左右两个部分。接下来,再对左右两个集合递归的进行上述步骤。

1. 对于左子$D_1$树来说

$$ \text{Gini}(D_1)=1-(\frac{4}{7})^2-(\frac{3}{7})^2\approx0.490\tag{8-55} $$

在特征$A_2$中有

$$ \text{Gini}(D_1, A_2\leq0.5)=\frac{3}{7}(1-1)+\frac{4}{7}(1-\frac{1}{16}-\frac{9}{16})\approx0.214\tag{8-56} $$

在特征$A_3$中有

$$ \begin{aligned} \text{Gini}(D_1,A_3\leq 0.5)&=\frac{1}{7}(1-1)+\frac{6}{7}(1-\frac{1}{4}-\frac{1}{4})\approx0.4286\\[2ex] \text{Gini}(D_1, A_3\leq 1.5)&=\frac{3}{7}(1-\frac{1}{9}-\frac{4}{9})+\frac{4}{7}(1-\frac{1}{16}-\frac{9}{16})\approx0.405\\[2ex] \end{aligned}\tag{8-57} $$

此时,基尼不纯度$\text{Gini}(D_1, A_2\leq 0.5)\approx0.214$最小,故选择$ A_2\leq 0.5$作为判断区间,此时划分结果如表8-7所示。

表 8-7.信用卡审批数据划分结果$D_1$表

2. 对于右子$D_2$树来说

$$ \text{Gini}(D_2)=1-(\frac{1}{8})^2-(\frac{7}{8})^2\approx0.219\tag{8-58} $$

在特征$A_2$中有

$$ \text{Gini}(D_2, A_2\leq0.5)=\frac{5}{8}(1-\frac{1}{25}-\frac{16}{25})+\frac{3}{8}(1-1)\approx0.2\tag{8-59} $$

在特征$A_3$中有

$$ \begin{aligned} \text{Gini}(D_2,A_3\leq 0.5)&=\frac{2}{8}(1-1)+\frac{6}{8}(1-\frac{1}{36}-\frac{25}{36})\approx0.208\\[2ex] \text{Gini}(D_2,A_3\leq 1.5)&=\frac{4}{8}(1-\frac{1}{16}-\frac{9}{16})+\frac{4}{8}(1-1)\approx0.188\\[2ex] \end{aligned}\tag{8-60} $$

此时,基尼不纯度$\text{Gini}(D_2, A_3\leq1.5)\approx0.188$最小,故选择 $A_3\leq1.5$作为判断区间,此时划分结果如表8-8所示。

表 8-8. 信用卡审批数据划分结果$D_2$表

进一步,根据表8-7和表8-8的划分结果便可以得到如图8-22所示的结果。

图 8-22. CART第二次划分

根据图8-22可知,此时对于集合$D_{11}$来说其只有一个类别,因此停止划分。

3. 对于集合$D_{12}$来说

$$ \text{Gini}(D_{12})=1-(\frac{1}{4})^2-(\frac{3}{4})^2\approx0.375\tag{8-61} $$

在特征$A_3$​中有

$$ \text{Gini}(D_{12}, A_3\leq 1.5)=\frac{2}{4}(1-1)+\frac{2}{4}(1-\frac{1}{4}-\frac{1}{4})=0.25\tag{8-62} $$

此时,基尼不纯度$\text{Gini}(D_{12},A_3\leq1.5)=0.25$最小,故选择$A_3\leq1.5$作为判断区间。

4. 对于集合$D_{21}$来说

$$ \text{Gini}(D_{12})=1-(\frac{1}{4})^2-(\frac{3}{4})^2\approx0.375\tag{8-63} $$

在特征$A_2$中有

$$ \begin{aligned} \text{Gini}(D_{21},A_2\leq0.5)=\frac{3}{4}(1-\frac{1}{9}-\frac{4}{9})+\frac{1}{4}(1-1)\approx0.333 \end{aligned}\tag{8-64} $$

此时,基尼不纯度$\text{Gini}(D_{21},A_2\leq0.5)\approx0.333$最小,故选择$A_2\leq0.5$作为判断区间。当然,对于特征$A_2$来说其只有一个取值区间,所以也可以不用计算。

对于集合$D_{22}$​来说其只有一个类别,所以也停止划分。这样便可以得到如表8-9所示的划分结果。

表 8-9. 信用卡审批数据划分结果表

最终,根据表8-9便可得到如图8-23所示的决策树。

图 8-23. CART分类决策树

这里值得一提的是,由于划分特征的处理方式不同,所以图8-23中的结果与图8-14中的结果略有差异。

8.8.2 连续型特征剪枝示例#

在简单介绍完连续型特征变量的CART分类树生成示例后,我们再以此生成结果为例来详细介绍整个剪枝过程。由图8-23可知,此时的决策树便是子树序列中的$T_0$​,且可以简化为如图8-24所示的结果。

图 8-24. 子树序列$T_0$

其中叶子节点中的数字表示每个类别的样本数。

此时,根据$T_0$可以构建$T_1$中 的各种剪枝情况,然后选择最优节点进行剪枝。对于$T_0$​来说,第1种情况为减掉5号节点,如图8-25所示。

图 8-25 子树序列$T_0,g(t_0)$

根据图8-24和图8-25可知,由式(8-45)和式(8-46)可分别计算得到剪枝前$C(T_t)$和剪枝后$C(t)$的损失分别为

$$ \begin{aligned} C(T_t)&=-(1\cdot\log\frac{1}{3}+2\cdot\log\frac{2}{3}+1\cdot\log1)\approx2.7549\\[2ex] C(t)&=-(1\cdot\log\frac{1}{4}+3\cdot\log\frac{3}{4})\approx3.2451 \end{aligned}\tag{8-65} $$

此时有

$$ g(t_0)=\frac{C(t)-C(T_t)}{|T_t|-1}=\frac{3.2451-2.7549}{2-1}\approx0.4902\tag{8-66} $$

对于$T_0$​来说,第2种情况为减掉4号节点,如图8-26所示。

图 8-26 子树序列$T_0,g(t_1)$

根据图8-23和图8-25可知,在剪枝前$C(T_t)$和剪枝后$C(t)$的损失以及$g(t_1)$分别为

$$ \begin{aligned} C(T_t)&=-(2\cdot\log\frac{2}{2}+1\cdot\log\frac{1}{2}+1\cdot\log\frac{1}{2})=2\\[2ex] C(t)&=-(1\cdot\log\frac{1}{4}+3\cdot\log\frac{3}{4})\approx3.2451\\[2ex] g(t_1)&=\frac{3.245-2}{2-1}=1.2451 \end{aligned}\tag{8-67} $$

对于$T_0$​来说,第3种情况为减掉3号节点,如图8-27所示。

图 8-27 子树序列$T_0,g(t_2)$

根据图8-24和图8-27可知,在剪枝前$C(T_t)$和剪枝后$C(t)$的损失以及$g(t_2)$分别为

$$ \begin{aligned} C(T_t)&=-(1\cdot\log\frac{1}{3}+2\cdot\log\frac{2}{3}+1\cdot\log\frac{1}{1}+4\cdot\log\frac{4}{4})\approx2.7549\\[2ex] C(t)&=-(1\cdot\log\frac{1}{8}+7\cdot\log\frac{7}{8})\approx4.3485\\[2ex] g(t_2)&=\frac{4.3485-2.7549}{3-1}=0.7968 \end{aligned}\tag{8-68} $$

对于$T_0$​来说,第4种情况为减掉2号节点,如图8-28所示。

图 8-28 子树序列$T_0,g(t_3)$

根据图8-24和图8-28可知,在剪枝前$C(T_t)$和剪枝后$C(t)$的损失以及$g(t_3)$分别为

$$ \begin{aligned} C(T_t)&=-(3\cdot\log\frac{3}{3}+2\cdot\log\frac{2}{2}+1\cdot\log\frac{1}{2}+1\cdot\log\frac{1}{2})=2\\[2ex] C(t)&=-(4\cdot\log\frac{4}{7}+3\cdot\log\frac{3}{7})\approx6.8966\\[2ex] g(t_3)&=\frac{6.897-2}{3-1}\approx2.4483 \end{aligned}\tag{8-69} $$

8.8.3 节点定义实现#

8.8.4 基尼不纯度实现#

8.8.5 决策树构建实现#

8.8.6 决策树预测实现#

8.8.7 决策树剪枝实现#

8.8.8 使用示例#

8.8.9 小结#

36