在CART决策树中使用分类功能
1 简介
决策树中的 CART 分类回归树
CART( Classification And Regression Trees 分类与回归树)模型是一种用于分类和回归任务的决策树算法,是决策树算法的一种变体。
它可以处理分类和回归任务。 Scikit-Learn 使用分类与回归树 (CART) 算法来训练决策树(也称为“增长”树)。
CART 于 1984 年由 Leo Breiman、Jerome Friedman、Richard Olshen 和 Charles Stone 首次制作。
它由Breiman等人于1986年提出,广泛应用于数据挖掘和机器学习中。
2 归类树实现示例
包括以下功能:
分类功能
移除所有剪枝参数和复杂配置
使用最简单的基尼不纯度作为分裂标准
核心保留:
完整的树构建过程
最佳分裂点寻找
预测功能
代码精简:
约100行核心代码
直观的数据结构
清晰的递归构建过程
示例数据:
使用最简单的2D分类数据
包含3个测试用例
package main
import (
"fmt"
)
// 树节点结构
type TreeNode struct {
IsLeaf bool
Class int // 叶子节点存储的类别
SplitIndex int // 分裂特征索引
SplitValue float64 // 分裂值
Left *TreeNode // 左子树
Right *TreeNode // 右子树
}
// 极简CART分类树
type SimpleCart struct {
Root *TreeNode
}
// 训练入口
func (c *SimpleCart) Fit(features [][]float64, labels []int) {
c.Root = c.buildTree(features, labels)
}
// 递归构建树
func (c *SimpleCart) buildTree(features [][]float64, labels []int) *TreeNode {
// 如果所有标签相同,创建叶子节点
if allSame(labels) {
return &TreeNode{IsLeaf: true, Class: labels[0]}
}
// 寻找最佳分裂
bestIdx, bestVal := c.findBestSplit(features, labels)
if bestIdx == -1 {
return &TreeNode{IsLeaf: true, Class: majorityClass(labels)}
}
// 分割数据
leftFeat, leftLab, rightFeat, rightLab := splitData(features, labels, bestIdx, bestVal)
// 递归构建子树
node := &TreeNode{
IsLeaf: false,
SplitIndex: bestIdx,
SplitValue: bestVal,
Left: c.buildTree(leftFeat, leftLab),
Right: c.buildTree(rightFeat, rightLab),
}
return node
}
// 寻找最佳分裂
func (c *SimpleCart) findBestSplit(features [][]float64, labels []int) (int, float64) {
bestGini := 1.0
bestIdx := -1
bestVal := 0.0
for i := 0; i < len(features[0]); i++ {
// 获取当前特征的所有唯一值
values := getUniqueValues(features, i)
for _, v := range values {
// 计算基尼指数
gini := c.calculateGini(features, labels, i, v)
if gini < bestGini {
bestGini = gini
bestIdx = i
bestVal = v
}
}
}
return bestIdx, bestVal
}
// 计算基尼指数
func (c *SimpleCart) calculateGini(features [][]float64, labels []int, splitIdx int, splitVal float64) float64 {
leftCounts := make(map[int]int)
rightCounts := make(map[int]int)
totalLeft := 0
totalRight := 0
for i, sample := range features {
if sample[splitIdx] <= splitVal {
leftCounts[labels[i]]++
totalLeft++
} else {
rightCounts[labels[i]]++
totalRight++
}
}
// 计算左右子树的基尼不纯度
giniLeft := 1.0
for _, count := range leftCounts {
p := float64(count) / float64(totalLeft)
giniLeft -= p * p
}
giniRight := 1.0
for _, count := range rightCounts {
p := float64(count) / float64(totalRight)
giniRight -= p * p
}
// 加权平均
total := float64(totalLeft + totalRight)
return (float64(totalLeft)/total)*giniLeft + (float64(totalRight)/total)*giniRight
}
// 预测函数
func (c *SimpleCart) Predict(feature []float64) int {
node := c.Root
for !node.IsLeaf {
if feature[node.SplitIndex] <= node.SplitValue {
node = node.Left
} else {
node = node.Right
}
}
return node.Class
}
// --- 辅助函数 ---
// 检查所有标签是否相同
func allSame(labels []int) bool {
if len(labels) == 0 {
return true
}
first := labels[0]
for _, label := range labels {
if label != first {
return false
}
}
return true
}
// 获取多数类别
func majorityClass(labels []int) int {
counts := make(map[int]int)
for _, label := range labels {
counts[label]++
}
maxCount := 0
majority := 0
for class, count := range counts {
if count > maxCount {
maxCount = count
majority = class
}
}
return majority
}
// 获取唯一值
func getUniqueValues(features [][]float64, idx int) []float64 {
unique := make(map[float64]bool)
for _, sample := range features {
unique[sample[idx]] = true
}
var values []float64
for v := range unique {
values = append(values, v)
}
return values
}
// 分割数据
func splitData(features [][]float64, labels []int, splitIdx int, splitVal float64) (
leftFeat [][]float64, leftLab []int, rightFeat [][]float64, rightLab []int) {
for i, sample := range features {
if sample[splitIdx] <= splitVal {
leftFeat = append(leftFeat, sample)
leftLab = append(leftLab, labels[i])
} else {
rightFeat = append(rightFeat, sample)
rightLab = append(rightLab, labels[i])
}
}
return
}
-
示例使用
func main() { // 极简训练数据 features := [][]float64{ {2.0, 1.0}, // 类别0 {2.0, 2.0}, // 类别0 {1.0, 2.0}, // 类别1 {1.0, 1.0}, // 类别1 } labels := []int{0, 0, 1, 1} // 创建并训练树 tree := SimpleCart{} tree.Fit(features, labels) // 测试预测 testCases := []struct { features []float64 expected int }{ {[]float64{1.5, 1.5}, 1}, // 接近决策边界 {[]float64{2.0, 1.5}, 0}, // 明确类别0 {[]float64{1.0, 1.5}, 1}, // 明确类别1 } for _, tc := range testCases { predicted := tree.Predict(tc.features) fmt.Printf("特征: %v, 预测: %d, 期望: %d\n", tc.features, predicted, tc.expected) } }
3 小结
常见问题
- 决策树学习的主要问题是什么?
决策树学习中的主要问题包括过拟合、对小数据变化的敏感性和有限的泛化。确保正确修剪、调整和处理不平衡的数据有助于缓解这些挑战,从而获得更强大的决策树模型。
- 决策树如何帮助决策?
决策树通过在分层结构中表示复杂的选择来帮助决策。每个节点都测试特定属性,根据数据值指导决策。叶节点提供最终结果,为机器学习中的决策分析提供了清晰且可解释的路径。
- 决策树的最大深度是多少?
决策树的最大深度是一个超参数,用于确定从根到任何叶的最大级别或节点数。它控制树的复杂度,并有助于防止过度拟合。
- 决策树的概念是什么?
决策树是一种监督式学习算法,它根据输入特征对决策进行建模。它形成一个树状结构,其中每个内部节点代表一个基于属性的决策,导致叶节点代表结果。
- 什么是决策树中的熵?
在决策树中,熵是数据集中杂质或无序的度量。它量化了与实例分类相关的不确定性,指导算法进行信息性拆分,以便做出有效的决策。
-
决策树的超参数是什么?
Max Depth(最大深度):树的最大深度。 Min Samples Split(最小样本拆分):拆分内部节点所需的最小样本数。 Min Samples Leaf:叶节点所需的最小样本数。 Criterion:用于衡量分割质量的函数.
- 点赞
- 收藏
- 关注作者
评论(0)