1. 决策树简介
决策树是一种监督学习算法,它可以对离散或连续的数据进行分类或回归分析。在决策树中,每个分支代表一个属性,每个叶子节点代表一个类别。由于决策树可以将数据分割成一系列的子集,使得每个子集内的实例尽可能属于同一类别,因此决策树也被称为“分裂损失最小的属性”。
决策树算法的主要优点是易于实现和解释,具有较高的分类性能。在应用领域中,决策树可以用于医学诊断、信用评估、风险分析、市场营销、图像识别等方面。
2. 决策树的构建过程
2.1 特征选择
在构建决策树的过程中,选择合适的属性进行分割是一个关键步骤。常见的特征选择方法包括信息增益、信息增益比、基尼指数等。
信息增益(ID3算法):信息增益是指在决策树中用某个属性分割样本集合,而得到的信息增强的大小,选取信息增益最大的属性作为划分标准。公式如下:
def calc_information_gain(X, y, feature_idx):
# 计算信息熵
entropy = calc_entropy(y)
# 根据属性划分数据集
X_left, y_left, X_right, y_right = split_data_by_feature(X, y, feature_idx)
# 计算子节点的熵
entropy_left = calc_entropy(y_left)
entropy_right = calc_entropy(y_right)
# 计算信息增益
info_gain = entropy - (len(y_left) / len(y) * entropy_left) - (len(y_right) / len(y) * entropy_right)
return info_gain
2.2 决策树的生成
根据上一步的特征选择方法,我们可以选择一个特征作为当前节点。对于此节点,根据节点的取值进行数据集划分,然后再针对每个子节点进行递归构建,直到每个子节点都为叶子节点。在构建过程中,需要判断何时停止递归,即如何判断节点为叶子节点。通常的方法有以下几种:
当所有的实例都属于同一类别时,停止递归,将此节点标记为叶子节点。
当划分的子数据集为空时,停止递归,将此节点标记为叶子节点。
当节点的深度达到预定的值时,停止递归,将此节点标记为叶子节点。
def build_decision_tree(X, y, max_depth=None):
# 判断是否停止递归
if len(y) == 0:
return None
if len(set(y)) == 1:
return DecisionNode(label=y[0])
if max_depth is not None and max_depth == 0:
return DecisionNode(label=get_majority_label(y))
# 选择最佳的属性划分数据集
best_feature_idx = choose_best_feature(X, y)
# 生成节点
node = DecisionNode(feature_idx=best_feature_idx)
# 递归生成左右节点
X_left, y_left, X_right, y_right = split_data_by_feature(X, y, best_feature_idx)
node.set_childs(build_decision_tree(X_left, y_left, max_depth-1), build_decision_tree(X_right, y_right, max_depth-1))
return node
3. 决策树的剪枝
决策树生成时,为了避免过拟合,我们可以通过剪枝策略来优化生成的决策树结构,以达到提高泛化能力的目的。常见的剪枝方法包括预剪枝和后剪枝两种。
3.1 预剪枝
预剪枝是指在决策树生成过程中,在递归生成节点时预先设定一些停止条件,若满足此条件,则返回当前节点并标记为叶子节点。预剪枝可以避免决策树的过拟合,但是可能会使决策树的预测性能下降。
def build_decision_tree(X, y, max_depth=None, min_samples_split=2):
...
# 预剪枝
if len(y) < min_samples_split or get_majority_label(y) >= 0.9:
return DecisionNode(label=get_majority_label(y))
...
3.2 后剪枝
后剪枝是指在决策树生成完成后,对已生成的决策树进行剪枝。通常,我们会将生成的决策树分为训练集和测试集两部分,对测试集进行剪枝,以减少决策树复杂度,提高泛化性能。
def post_pruning(decision_tree, X_test, y_test):
if decision_tree is None:
return None
# 如果为叶子节点,则返回本身
if decision_tree.is_leaf_node():
return decision_tree
# 递归后剪枝左右子树
left_subtree = post_pruning(decision_tree.left_child, X_test, y_test)
right_subtree = post_pruning(decision_tree.right_child, X_test, y_test)
decision_tree.set_childs(left_subtree, right_subtree)
# 计算决策树在测试集上的精度
accuracy = get_accuracy(decision_tree, X_test, y_test)
# 计算不剪枝的精度
non_pruned_accuracy = get_accuracy(decision_tree, X_test, y_test, prune=False)
# 如果剪枝后精度没有下降,则进行剪枝处理
if accuracy >= non_pruned_accuracy:
decision_tree.left_child = decision_tree.right_child = None # 剪枝
decision_tree.label = get_majority_label(y_test) # 将剪枝后的叶子节点标记为样本集中的majority类别
return decision_tree
4. 总结
本文主要介绍了决策树算法的构建方法以及剪枝策略,决策树作为一种监督学习算法,具有易解释、易实现和较高分类性能等特点,被广泛应用于医学诊断、风险分析、市场营销、图像识别等领域。在构建决策树过程中,特征选择是一个关键步骤,常用的方法有信息增益、信息增益比、基尼指数等。在剪枝策略中,预剪枝和后剪枝是常用的方法,它们可以避免决策树的过拟合,提高模型的泛化能力。