决策树数据构建和训练示例
1 简介
作为最流行的经典机器学习算法之一,决策树的可解释性比其他算法更直观。
CART算法经常用于构建决策树模型,它可能也是最常用的算法。
当我们将 Scikit-Learn 库用于决策树分类器时,它是默认算法。
用 scikit-learn 构建一个 CART 决策树模型;
计算每个叶节点上“流失”类别的概率;
根据概率阈值为每个叶节点打上模态标签:
□ churn(必然流失,概率 > 0.9)
□ not churn(必然不流失,概率 < 0.1)
◇ churn/◇ not churn(可能流失/不流失,0.1 ≤ 概率 ≤ 0.9)
将每个样本映射到对应叶节点,并附加该节点的模态标签。
2 叶子节点模态标签
你可以在“Leaf Node Modal Labels”表格中看到,各叶节点的 chanr概率 及对应的 modal_label。
- 样本级别预测
下表展示了每个样本的特征、真实标签、落入的叶子节点以及该节点的模态判断:
age monthly_spend churn leaf_id probability_of_churn modal_label
25 1200 1 1 1.0 □ churn
27 1500 1 1 1.0 □ churn
35 800 0 2 0.0 □ not churn
… … … … … …
这样,一方面你有传统的“分类”结果,另一方面你得到 模态逻辑层面的不确定性标注,便于解释与决策支持。
后续:
你可以调整概率阈值,或者用更细粒度的模态算子(比如多级可能性)。
在生产环境中,还能结合知识图谱,对模态公式添加先验约束,提升推理的严密性。
使用python实现模态模拟
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
3 构造和训练
-
示例数据集
data = { 'age': [25, 27, 35, 22, 45, 30, 28, 40, 50, 23], 'monthly_spend': [1200, 1500, 800, 900, 700, 1100, 1300, 950, 600, 1000], 'churn': [1, 1, 0, 1, 0, 1, 1, 0, 0, 1] } df = pd.DataFrame(data) X = df[['age', 'monthly_spend']] y = df['churn']
-
训练 CART 决策树
clf = DecisionTreeClassifier(max_depth=3, random_state=42) clf.fit(X, y)
-
提取每个样本对应的叶子节点
leaf_indices = clf.apply(X)
-
计算每个叶子节点上的类别概率和模态标签
tree = clf.tree_ leaf_nodes = np.where(tree.children_left == -1)[0] leaf_info = [] for node in leaf_nodes: counts = tree.value[node][0] total = counts.sum() prob_pos = counts[1] / total # 确定预测类别及模态标签 if prob_pos > 0.9: modal = '□ churn' pred_class = 1 elif prob_pos < 0.1: modal = '□ not churn' pred_class = 0 else: if prob_pos >= 0.5: modal = '◇ churn' pred_class = 1 else: modal = '◇ not churn' pred_class = 0 leaf_info.append({ 'leaf_id': node, 'predicted_class': pred_class, 'probability_of_churn': round(prob_pos, 2), 'modal_label': modal }) leaf_df = pd.DataFrame(leaf_info) import ace_tools as tools; tools.display_dataframe_to_user(name="Leaf Node Modal Labels", dataframe=leaf_df)
-
样本映射到对应的模态标签
sample_info = df.copy()
sample_info[‘leaf_id’] = leaf_indices
sample_info = sample_info.merge(leaf_df, on=‘leaf_id’, how=‘left’)
sample_info
4 小结
CART 算法是 Classification And Regression Trees 的缩写。它是由 Breiman 等人于 1984 年发明的。
它通常与 C4.5 非常相似,但具有以下主要特征:
CART 不是可以有多个分支的通用树,而是使用二叉树,每个节点只有两个分支。
CART 使用 Gini Impurity 作为拆分节点的标准,而不是 Information Gain。
CART 支持数字目标变量,这使它本身能够成为预测连续值的回归树。
本文重点介绍了作为分类树的 CART。
就像依赖信息增益作为拆分节点的标准的 ID3 和 C4.5 算法一样,CART 算法使用另一个称为 Gini 的标准来拆分节点.
- 点赞
- 收藏
- 关注作者
评论(0)