本文参考了周志华老师的《机器学习》(俗称“西瓜书”)。这里是 第四章“决策树” 的阅读笔记。本文专注于决策树的三种核心算法这一块,并且记录了我的思考,希望对你有所帮助🎉

1. 决策树简介

决策树是一种基于树结构的预测模型,常用于分类和回归任务。它通过一系列“条件判断”(树的分支)将数据划分为不同的子集,最终输出预测值(叶节点)。决策树的主要优点是直观易理解、不需要标准化特征,同时具有较高的解释性。

主要用途


2. 决策树的核心问题:划分标准

构建决策树的关键在于选择最佳的划分属性,也就是选择一个最优的特征将数据集分割为子集。划分标准的常见算法包括:

  1. 信息增益:基于熵的减少量选择最佳特征。
  2. 信息增益率:对信息增益进行归一化,解决信息增益对多值属性的偏好。
  3. 基尼指数:基于分类的不纯度选择划分特征。

以下分别对这三个划分标准进行详细介绍。


3. 信息增益 (Information Gain)

3.1 基本思想

信息增益衡量的是在某个特征 (A) 上划分数据集 (D) 后,数据的不确定性(信息熵)降低的程度。信息增益越大,说明这个特征对分类越重要。

3.2 信息熵
信息熵衡量了数据的不确定性,公式为:

$$ H(D) = -\sum_{k=1}^K p_k \log_2(p_k) $$

3.3 信息增益公式 / ID3算法
对于特征 A,其信息增益公式为:

$$ G(D, A) = H(D) - \sum_{v \in \text{Values}(A)} \frac{|D_v|}{|D|} H(D_v) $$

$$ \text{Values}(A) $$

是特征 A 的所有可能取值。

$$ |D_v| $$

是特征 A取值为 v 的子数据集大小。

$$ |D| $$

是数据集的总大小。

$$ H(D_v) $$

是子数据集 D_v的信息熵。

优点

缺点


4. 信息增益率 (Information Gain Ratio) / C4.5算法

4.1 基本思想

信息增益容易偏向于取值较多的特征,信息增益率通过引入一个归一化因子来克服这一问题。

4.2 公式
信息增益率定义为信息增益与特征的固有值之比:

$$ GR(D, A) = \frac{G(D, A)}{H_A(D)} $$

$$ H_A(D) = -\sum_{v \in \text{Values}(A)} \frac{|D_v|}{|D|} \log_2 \left( \frac{|D_v|}{|D|} \right) $$

优点

缺点


5. 基尼指数 (Gini Index) / CART算法

5.1 基本思想

基尼指数是另一种衡量数据纯度的指标,通常用于分类任务。它反映了随机选取两个样本,其类别不同的概率。

5.2 公式
数据集 D 的基尼指数定义为:

$$ Gini(D) = 1 - \sum_{k=1}^K p_k^2 $$

对于特征 A,按其不同取值划分后的基尼指数为:

$$ Gini_A(D) = \sum_{v \in \text{Values}(A)} \frac{|D_v|}{|D|} Gini(D_v) $$

特征 A 的基尼增益为:

$$ \Delta Gini = Gini(D) - Gini_A(D) $$

优点

缺点


6. Scikit-learn 实现决策树

以下代码使用 Scikit-learn 的 DecisionTreeClassifier,并展示如何选择不同的划分标准。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 1. 加载数据集
data = load_iris()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 2. 使用信息增益 (entropy) 构建决策树
tree_entropy = DecisionTreeClassifier(criterion='entropy', random_state=42)
tree_entropy.fit(X_train, y_train)
y_pred_entropy = tree_entropy.predict(X_test)

print("信息增益 (Entropy) 决策树准确率:", accuracy_score(y_test, y_pred_entropy))

# 3. 使用基尼指数 (gini) 构建决策树
tree_gini = DecisionTreeClassifier(criterion='gini', random_state=42)
tree_gini.fit(X_train, y_train)
y_pred_gini = tree_gini.predict(X_test)

print("基尼指数 (Gini) 决策树准确率:", accuracy_score(y_test, y_pred_gini))

# 4. 决策树可视化
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 8))
plot_tree(tree_entropy, feature_names=data.feature_names, class_names=data.target_names, filled=True)
plt.title("Decision Tree (Entropy)")
plt.show()

结果展示

信息增益 (Entropy) 决策树准确率: 0.9777777777777777
基尼指数 (Gini) 决策树准确率: 1.0

7. 头脑风暴

  1. 决策树三种算法的优缺点总结
算法 优点 缺点
ID3 - 使用信息增益选择特征,计算简单。
- 适合处理离散型特征。
- 构建的树相对直观易懂。
- 偏向于特征取值较多的属性(易过拟合)。
- 无法处理连续型特征。
- 对噪声数据较敏感。
C4.5 - 使用信息增益比选择特征,避免 ID3 偏向多取值属性问题。
- 支持连续型特征处理。
- 生成的树更简洁。
- 计算复杂度较高(需要对连续特征排序)。
- 易受样本数量少的类别影响。
- 生成的树可能过复杂。
CART - 使用基尼指数作为标准,计算高效。
- 支持分类和回归任务。
- 生成的树结果更稳定。
- 基尼指数可能偏向于特征取值较多的属性(与 ID3 类似)。
- 不适合多分类任务(需额外处理)。

总结要点

  1. ID3:以信息增益为准则,简单易实现,但对连续特征和多取值属性处理较差。

  2. C4.5:改进了 ID3 的不足,支持连续特征处理,避免多取值偏向,但计算复杂度较高。

  3. CART:分类回归都支持,计算高效稳定,但对多分类任务不够友好。

  4. 如何让决策树更好地应对特征关联性和不平衡数据?

背景

处理特征关联性

  1. 传统问题
    如果特征间有强烈相关性,决策树可能优先选择某些冗余特征,导致模型不稳定,分裂效果下降。

  2. 解决方法

    • 特征选择或降维
      • 在建树前通过统计方法移除高度相关的特征(如相关系数矩阵)。
      • 使用 PCA 等降维方法将高维相关特征转化为非相关的主成分。
    • 随机分裂特征(Random Feature Selection)
      • 随机森林通过限制分裂时的候选特征集合(如每次仅从所有特征中选择部分特征)有效解决了特征关联性问题。

应对不平衡数据

  1. 传统问题
    决策树对多数类偏向明显,叶节点的纯度更多依赖于多数类,导致少数类难以正确分类。

  2. 解决方法

    • 调整样本权重
      • 在分裂准则中为少数类样本赋予更高的权重,增加其对分裂点的影响。
    • 平衡采样
      • 使用上采样(重复少数类样本)或下采样(减少多数类样本)的方法来平衡类别分布。
    • 引入代价敏感学习
      • 为误分类的少数类样本赋予更高的惩罚代价,激励模型更关注少数类。

3.如何提升决策树的鲁棒性,特别是应对对抗样本和高维数据的复杂决策?

背景


应对对抗样本

  1. 传统问题
    决策树分裂过程基于固定规则,对小扰动不敏感,但集成模型(如随机森林)可能被对抗样本攻击。

  2. 解决方法

    • 增加分裂点随机性
      • 在分裂时引入随机性(如随机森林的随机特征选择),使模型对固定模式的攻击更鲁棒。
    • 对抗训练
      • 在训练时加入对抗样本,增强模型对未知样本的适应性。
    • 加权叶节点输出
      • 根据叶节点样本分布对输出进行加权,减少单一决策路径的影响。

优化高维数据的决策逻辑

  1. 传统问题
    高维数据导致树的深度和复杂度显著增加,叶节点数量激增,不易解释模型决策逻辑。

  2. 解决方法

    • 剪枝技术(Pruning)
      • 使用预剪枝或后剪枝方法减少树的复杂性。
      • 例如,当节点分裂不能显著提高模型性能时,停止分裂。
    • 特征降维
      • 使用特征选择或降维方法减少输入维度,如 L1 正则化筛选特征。
    • 规则提取
      • 将决策树简化为一组易理解的规则集合,聚合重要分裂路径。

文章参考