在CART决策树中使用分类功能

举报
码乐 发表于 2025/06/15 08:01:03 2025/06/15
【摘要】 1 简介决策树中的 CART 分类回归树CART( Classification And Regression Trees 分类与回归树)模型是一种用于分类和回归任务的决策树算法,是决策树算法的一种变体。它可以处理分类和回归任务。 Scikit-Learn 使用分类与回归树 (CART) 算法来训练决策树(也称为“增长”树)。CART 于 1984 年由 Leo Breiman、Jerom...

1 简介

决策树中的 CART 分类回归树

CART( Classification And Regression Trees 分类与回归树)模型是一种用于分类和回归任务的决策树算法,是决策树算法的一种变体。
它可以处理分类和回归任务。 Scikit-Learn 使用分类与回归树 (CART) 算法来训练决策树(也称为“增长”树)。

CART 于 1984 年由 Leo Breiman、Jerome Friedman、Richard Olshen 和 Charles Stone 首次制作。
它由Breiman等人于1986年提出,广泛应用于数据挖掘和机器学习中。

2 归类树实现示例

包括以下功能:

分类功能

移除所有剪枝参数和复杂配置

使用最简单的基尼不纯度作为分裂标准

核心保留:

完整的树构建过程

最佳分裂点寻找

预测功能

代码精简:

约100行核心代码

直观的数据结构

清晰的递归构建过程

示例数据:

使用最简单的2D分类数据

包含3个测试用例

      package main

      import (
          "fmt"
      )

      // 树节点结构
      type TreeNode struct {
          IsLeaf     bool
          Class      int       // 叶子节点存储的类别
          SplitIndex int       // 分裂特征索引
          SplitValue float64   // 分裂值
          Left       *TreeNode // 左子树
          Right      *TreeNode // 右子树
      }

      // 极简CART分类树
      type SimpleCart struct {
          Root *TreeNode
      }

      // 训练入口
      func (c *SimpleCart) Fit(features [][]float64, labels []int) {
          c.Root = c.buildTree(features, labels)
      }

      // 递归构建树
      func (c *SimpleCart) buildTree(features [][]float64, labels []int) *TreeNode {
          // 如果所有标签相同,创建叶子节点
          if allSame(labels) {
              return &TreeNode{IsLeaf: true, Class: labels[0]}
          }

          // 寻找最佳分裂
          bestIdx, bestVal := c.findBestSplit(features, labels)
          if bestIdx == -1 {
              return &TreeNode{IsLeaf: true, Class: majorityClass(labels)}
          }

          // 分割数据
          leftFeat, leftLab, rightFeat, rightLab := splitData(features, labels, bestIdx, bestVal)

          // 递归构建子树
          node := &TreeNode{
              IsLeaf:     false,
              SplitIndex: bestIdx,
              SplitValue: bestVal,
              Left:       c.buildTree(leftFeat, leftLab),
              Right:      c.buildTree(rightFeat, rightLab),
          }
          return node
      }

      // 寻找最佳分裂
      func (c *SimpleCart) findBestSplit(features [][]float64, labels []int) (int, float64) {
          bestGini := 1.0
          bestIdx := -1
          bestVal := 0.0

          for i := 0; i < len(features[0]); i++ {
              // 获取当前特征的所有唯一值
              values := getUniqueValues(features, i)
              for _, v := range values {
                  // 计算基尼指数
                  gini := c.calculateGini(features, labels, i, v)
                  if gini < bestGini {
                      bestGini = gini
                      bestIdx = i
                      bestVal = v
                  }
              }
          }
          return bestIdx, bestVal
      }

      // 计算基尼指数
      func (c *SimpleCart) calculateGini(features [][]float64, labels []int, splitIdx int, splitVal float64) float64 {
          leftCounts := make(map[int]int)
          rightCounts := make(map[int]int)
          totalLeft := 0
          totalRight := 0

          for i, sample := range features {
              if sample[splitIdx] <= splitVal {
                  leftCounts[labels[i]]++
                  totalLeft++
              } else {
                  rightCounts[labels[i]]++
                  totalRight++
              }
          }

          // 计算左右子树的基尼不纯度
          giniLeft := 1.0
          for _, count := range leftCounts {
              p := float64(count) / float64(totalLeft)
              giniLeft -= p * p
          }

          giniRight := 1.0
          for _, count := range rightCounts {
              p := float64(count) / float64(totalRight)
              giniRight -= p * p
          }

          // 加权平均
          total := float64(totalLeft + totalRight)
          return (float64(totalLeft)/total)*giniLeft + (float64(totalRight)/total)*giniRight
      }

      // 预测函数
      func (c *SimpleCart) Predict(feature []float64) int {
          node := c.Root
          for !node.IsLeaf {
              if feature[node.SplitIndex] <= node.SplitValue {
                  node = node.Left
              } else {
                  node = node.Right
              }
          }
          return node.Class
      }

      // --- 辅助函数 ---

      // 检查所有标签是否相同
      func allSame(labels []int) bool {
          if len(labels) == 0 {
              return true
          }
          first := labels[0]
          for _, label := range labels {
              if label != first {
                  return false
              }
          }
          return true
      }

      // 获取多数类别
      func majorityClass(labels []int) int {
          counts := make(map[int]int)
          for _, label := range labels {
              counts[label]++
          }
          maxCount := 0
          majority := 0
          for class, count := range counts {
              if count > maxCount {
                  maxCount = count
                  majority = class
              }
          }
          return majority
      }

      // 获取唯一值
      func getUniqueValues(features [][]float64, idx int) []float64 {
          unique := make(map[float64]bool)
          for _, sample := range features {
              unique[sample[idx]] = true
          }
          var values []float64
          for v := range unique {
              values = append(values, v)
          }
          return values
      }

      // 分割数据
      func splitData(features [][]float64, labels []int, splitIdx int, splitVal float64) (
          leftFeat [][]float64, leftLab []int, rightFeat [][]float64, rightLab []int) {

          for i, sample := range features {
              if sample[splitIdx] <= splitVal {
                  leftFeat = append(leftFeat, sample)
                  leftLab = append(leftLab, labels[i])
              } else {
                  rightFeat = append(rightFeat, sample)
                  rightLab = append(rightLab, labels[i])
              }
          }
          return
      }
  • 示例使用

       func main() {
           // 极简训练数据
           features := [][]float64{
               {2.0, 1.0}, // 类别0
               {2.0, 2.0}, // 类别0
               {1.0, 2.0}, // 类别1
               {1.0, 1.0}, // 类别1
           }
           labels := []int{0, 0, 1, 1}
    
           // 创建并训练树
           tree := SimpleCart{}
           tree.Fit(features, labels)
    
           // 测试预测
           testCases := []struct {
               features []float64
               expected int
           }{
               {[]float64{1.5, 1.5}, 1}, // 接近决策边界
               {[]float64{2.0, 1.5}, 0}, // 明确类别0
               {[]float64{1.0, 1.5}, 1}, // 明确类别1
           }
    
           for _, tc := range testCases {
               predicted := tree.Predict(tc.features)
               fmt.Printf("特征: %v, 预测: %d, 期望: %d\n", tc.features, predicted, tc.expected)
           }
       }
    

3 小结

常见问题

  1. 决策树学习的主要问题是什么?

决策树学习中的主要问题包括过拟合、对小数据变化的敏感性和有限的泛化。确保正确修剪、调整和处理不平衡的数据有助于缓解这些挑战,从而获得更强大的决策树模型。

  1. 决策树如何帮助决策?

决策树通过在分层结构中表示复杂的选择来帮助决策。每个节点都测试特定属性,根据数据值指导决策。叶节点提供最终结果,为机器学习中的决策分析提供了清晰且可解释的路径。

  1. 决策树的最大深度是多少?

决策树的最大深度是一个超参数,用于确定从根到任何叶的最大级别或节点数。它控制树的复杂度,并有助于防止过度拟合。

  1. 决策树的概念是什么?

决策树是一种监督式学习算法,它根据输入特征对决策进行建模。它形成一个树状结构,其中每个内部节点代表一个基于属性的决策,导致叶节点代表结果。

  1. 什么是决策树中的熵?

在决策树中,熵是数据集中杂质或无序的度量。它量化了与实例分类相关的不确定性,指导算法进行信息性拆分,以便做出有效的决策。

  1. 决策树的超参数是什么?

     	Max Depth(最大深度):树的最大深度。
     	Min Samples Split(最小样本拆分):拆分内部节点所需的最小样本数。
     	Min Samples Leaf:叶节点所需的最小样本数。
     	Criterion:用于衡量分割质量的函数.
    
【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。