ID3算法与C4.5算法的实现

举报
darkpard 发表于 2021/10/29 20:50:14 2021/10/29
【摘要】 学习笔记|决策树的特征选择、学习笔记|决策树的生成、学习笔记决策树的剪枝分别介绍了决策树算法的特征选择、树生成与剪枝,三者结合者是完整的决策树算法。1. 信息增益(比)的计算信息增益和信息增益比分别是ID3算法和C4.5算法特征选择的方法。首先,计算信息熵。def count_entropy(y): e = dict(Counter(y)) return np.sum([- v ...

学习笔记|决策树的特征选择学习笔记|决策树的生成学习笔记决策树的剪枝分别介绍了决策树算法的特征选择、树生成与剪枝,三者结合者是完整的决策树算法。

1. 信息增益(比)的计算

信息增益和信息增益比分别是ID3算法和C4.5算法特征选择的方法。

首先,计算信息熵。

def count_entropy(y):
    e = dict(Counter(y))
    return np.sum([- v / len(y) * np.log2(v / len(y)) for k, v in e.items()])

其次,计算条件熵。

def count_conditional_entropy(x, y):
    ce = 0
    for xx in np.unique(x):
        yx = y[np.where(x == xx)]
        ce += count_entropy(yx) * len(yx) / len(y)
    return ce

再次,计算信息增益。

def count_information_gain(x, y):
    return count_entropy(y) - count_conditional_entropy(x, y)

第四,计算信息增益比。

def count_information_gain_ratio(x, y):
    return count_information_gain(x, y) / count_entropy(x)

最后,验证信息增益的计算。

if __name__ == "__main__":
    y = np.array(['n', 'n', 'y', 'y', 'n', 'n', 'n', 'y', 'y', 'y', 'y', 'y', 'y', 'y', 'n'])
    print(count_information_gain(x[:, 0], y))

可以得到信息增益为0.08300749985576883,与参考文献1中的案例结果一致。

2. 决策树的生成

决策树的生成代码如下:

def generate_tree(x, y, A, algrithm='id3', epsilon=0):
    algrithm_function = count_information_gain if algrithm == 'id3' else count_information_gain_ratio
    if count_entropy(y) == 0:
        return (y[0], x, y)
    c = get_max_class(y)
    if not len(A):
        return (c, x, y)
    g = [algrithm_function(x[:, i], y) for i in range(x.shape[1])]
    if np.max(g) < epsilon:
        return (c, x, y)
    ai = np.argmax(g)
    A.pop(ai)
    t = {}
    for a in np.unique(x[:, ai]):
        si = np.where(x[:, ai] == a)
        t[a] = generate_tree(x[si], y[si], A, algrithm=algrithm, epsilon=epsilon)
    return t

当algrithm为id3时,algrithm_function为信息增益计算函数,否则是信息增益比的计算函数。当返回的是叶结点时,用三维元组来表示一个叶结点,x和y共同构成了这个叶结点下的样本,而c是样本数最多的分类,通过以下函数实现。

def get_max_class(y):
    d = dict(Counter(y))
    return list(d.keys())[np.argmax([v for v in d.values()])]

通过以下代码可以简单验证决策树生成效果。

if __name__ == "__main__":
    y = np.array(['n', 'n', 'y', 'y', 'n', 'n', 'n', 'y', 'y', 'y', 'y', 'y', 'y', 'y', 'n'])
    x = np.array([['青年', '无工作', '无房子', '一般'], ['青年', '无工作', '无房子', '好'], ['青年', '有工作', '无房子', '好'], ['青年', '有工作', '有房子', '一般'],
                  ['青年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '好'], ['中年', '有工作', '有房子', '好'],
                  ['中年', '无工作', '有房子', '非常好'], ['中年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '好'],
                  ['老年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '非常好'], ['老年', '无工作', '无房子', '一般']])
    T = generate_tree(x, y, list(range(x.shape[1])))
    print(T)

效果如下:

{'无房子': {'无工作': ('n', array([['青年', '无工作', '无房子', '一般'], ['青年', '无工作', '无房子', '好'], ['青年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '好'], ['老年', '无工作', '无房子', '一般']], dtype='<U3'), array(['n', 'n', 'n', 'n', 'n', 'n'], dtype='<U1')), '有工作': ('y', array([['青年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '非常好']], dtype='<U3'), array(['y', 'y', 'y'], dtype='<U1'))}, '有房子': ('y', array([['青年', '有工作', '有房子', '一般'], ['中年', '有工作', '有房子', '好'], ['中年', '无工作', '有房子', '非常好'], ['中年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '好']], dtype='<U3'), array(['y', 'y', 'y', 'y', 'y', 'y'], dtype='<U1'))}

3. 剪枝

通过以下函数实现剪枝。

def cut_leaf(T, alpha):
    uncut_entropy = 0
    supper_y = []
    supper_x = []
    son_num = 0
    for k, v in T.items():
        if isinstance(v, dict):
            T[k] = cut_leaf(v, alpha=alpha)
    for k, v in T.items():
        if isinstance(v, tuple):
            uncut_entropy += len(v[2]) * count_entropy(v[2])
            supper_y += list(v[2])
            supper_x += list(v[1])
            son_num += 1
        else:
            return T
    cut_entropy = len(supper_y) * count_entropy(np.array(supper_y))
    if cut_entropy < uncut_entropy + son_num * alpha:
        print(supper_y)
        return (get_max_class(np.array(supper_y)), np.array(supper_x), np.array(supper_y))
    return T

其中,count_T是计算决策树T的结点数,由于存储格式的原因,这里计算的节点数需要+1。

def count_T(T):
    c = len(T.keys())
    for k, v in T.items():
        if isinstance(v, dict):
            c += count_T(v)
    return c

参考文献

【1】统计学习方法(第2版),李航著,清华大学出版社

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。