ID3算法与C4.5算法的实现
学习笔记|决策树的特征选择、学习笔记|决策树的生成、学习笔记决策树的剪枝分别介绍了决策树算法的特征选择、树生成与剪枝,三者结合者是完整的决策树算法。
1. 信息增益(比)的计算
信息增益和信息增益比分别是ID3算法和C4.5算法特征选择的方法。
首先,计算信息熵。
def count_entropy(y):
e = dict(Counter(y))
return np.sum([- v / len(y) * np.log2(v / len(y)) for k, v in e.items()])
其次,计算条件熵。
def count_conditional_entropy(x, y):
ce = 0
for xx in np.unique(x):
yx = y[np.where(x == xx)]
ce += count_entropy(yx) * len(yx) / len(y)
return ce
再次,计算信息增益。
def count_information_gain(x, y):
return count_entropy(y) - count_conditional_entropy(x, y)
第四,计算信息增益比。
def count_information_gain_ratio(x, y):
return count_information_gain(x, y) / count_entropy(x)
最后,验证信息增益的计算。
if __name__ == "__main__":
y = np.array(['n', 'n', 'y', 'y', 'n', 'n', 'n', 'y', 'y', 'y', 'y', 'y', 'y', 'y', 'n'])
print(count_information_gain(x[:, 0], y))
可以得到信息增益为0.08300749985576883,与参考文献1中的案例结果一致。
2. 决策树的生成
决策树的生成代码如下:
def generate_tree(x, y, A, algrithm='id3', epsilon=0):
algrithm_function = count_information_gain if algrithm == 'id3' else count_information_gain_ratio
if count_entropy(y) == 0:
return (y[0], x, y)
c = get_max_class(y)
if not len(A):
return (c, x, y)
g = [algrithm_function(x[:, i], y) for i in range(x.shape[1])]
if np.max(g) < epsilon:
return (c, x, y)
ai = np.argmax(g)
A.pop(ai)
t = {}
for a in np.unique(x[:, ai]):
si = np.where(x[:, ai] == a)
t[a] = generate_tree(x[si], y[si], A, algrithm=algrithm, epsilon=epsilon)
return t
当algrithm为id3时,algrithm_function为信息增益计算函数,否则是信息增益比的计算函数。当返回的是叶结点时,用三维元组来表示一个叶结点,x和y共同构成了这个叶结点下的样本,而c是样本数最多的分类,通过以下函数实现。
def get_max_class(y):
d = dict(Counter(y))
return list(d.keys())[np.argmax([v for v in d.values()])]
通过以下代码可以简单验证决策树生成效果。
if __name__ == "__main__":
y = np.array(['n', 'n', 'y', 'y', 'n', 'n', 'n', 'y', 'y', 'y', 'y', 'y', 'y', 'y', 'n'])
x = np.array([['青年', '无工作', '无房子', '一般'], ['青年', '无工作', '无房子', '好'], ['青年', '有工作', '无房子', '好'], ['青年', '有工作', '有房子', '一般'],
['青年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '好'], ['中年', '有工作', '有房子', '好'],
['中年', '无工作', '有房子', '非常好'], ['中年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '好'],
['老年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '非常好'], ['老年', '无工作', '无房子', '一般']])
T = generate_tree(x, y, list(range(x.shape[1])))
print(T)
效果如下:
{'无房子': {'无工作': ('n', array([['青年', '无工作', '无房子', '一般'], ['青年', '无工作', '无房子', '好'], ['青年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '好'], ['老年', '无工作', '无房子', '一般']], dtype='<U3'), array(['n', 'n', 'n', 'n', 'n', 'n'], dtype='<U1')), '有工作': ('y', array([['青年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '非常好']], dtype='<U3'), array(['y', 'y', 'y'], dtype='<U1'))}, '有房子': ('y', array([['青年', '有工作', '有房子', '一般'], ['中年', '有工作', '有房子', '好'], ['中年', '无工作', '有房子', '非常好'], ['中年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '好']], dtype='<U3'), array(['y', 'y', 'y', 'y', 'y', 'y'], dtype='<U1'))}
3. 剪枝
通过以下函数实现剪枝。
def cut_leaf(T, alpha):
uncut_entropy = 0
supper_y = []
supper_x = []
son_num = 0
for k, v in T.items():
if isinstance(v, dict):
T[k] = cut_leaf(v, alpha=alpha)
for k, v in T.items():
if isinstance(v, tuple):
uncut_entropy += len(v[2]) * count_entropy(v[2])
supper_y += list(v[2])
supper_x += list(v[1])
son_num += 1
else:
return T
cut_entropy = len(supper_y) * count_entropy(np.array(supper_y))
if cut_entropy < uncut_entropy + son_num * alpha:
print(supper_y)
return (get_max_class(np.array(supper_y)), np.array(supper_x), np.array(supper_y))
return T
其中,count_T是计算决策树T的结点数,由于存储格式的原因,这里计算的节点数需要+1。
def count_T(T):
c = len(T.keys())
for k, v in T.items():
if isinstance(v, dict):
c += count_T(v)
return c
参考文献
【1】统计学习方法(第2版),李航著,清华大学出版社
- 点赞
- 收藏
- 关注作者
评论(0)