学习笔记|CART算法的实现
1. 生成最小二乘回归树的算法实现
首先,计算(j,s)确定情况下的损失:
def cal_regression_loss(x, y, j, s):
r1 = np.where(x[:, j] <= s)
r2 = np.where(x[:, j] > s)
c1 = get_max_class(y[r1])
c2 = get_max_class(y[r2])
return np.sum((y[r1] - c1) * (y[r1] - c1)) + np.sum((y[r2] - c2) * (y[r2] - c2)), r1, r2
(见学习笔记|CART算法)的样本序号;get_max_class函数见ID3算法与C4.5算法的实现。
然后,生成最小二乘回归树:
def generate_tree(x, y):
if not len(y):
return None
cut = []
for j in range(x.shape[1]):
unique_x = np.unique(x[:, j])
if len(unique_x) > 1:
for s in unique_x:
l, r1, r2 = loss_function(x, y, j, s)
if not len(cut):
cut = [j, s, l, r1, r2]
elif cut[2] > l:
cut = [j, s, l, r1, r2]
if len(cut):
t = binary_tree({'j': cut[0], 's': cut[1], 'x': x, 'y': y, 'c': get_max_class(y)})
t.left_child = generate_tree(x[cut[3]], y[cut[3]])
t.right_child = generate_tree(x[cut[4]], y[cut[4]])
return t
return binary_tree({'x': x, 'y': y, 'c': get_max_class(y)})
如果样本数量为0,则返回None;对(j,s)进行嵌套循环,对每一对(j,s)计算最小二乘损失,找到使损失最小的(j,s),初始化二叉树后(二叉树类binary_tree见学习笔记|k近邻法的实现),根据x的第j列将样本分成两部分,依次划分为左子树和右子树并递归地进行树的生成;如果不存在(j,s),则返回当前的样本。
仍使用ID3算法与C4.5算法的实现中的例子,将原有分类用数字表示后,可以尝试如下:
if __name__ == "__main__":
y = np.array(['n', 'n', 'y', 'y', 'n', 'n', 'n', 'y', 'y', 'y', 'y', 'y', 'y', 'y', 'n'])
x = np.array([['青年', '无工作', '无房子', '一般'], ['青年', '无工作', '无房子', '好'], ['青年', '有工作', '无房子', '好'], ['青年', '有工作', '有房子', '一般'],
['青年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '好'], ['中年', '有工作', '有房子', '好'],
['中年', '无工作', '有房子', '非常好'], ['中年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '好'],
['老年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '非常好'], ['老年', '无工作', '无房子', '一般']])
x[:, 0] = (x[:, 0] == '老年') + 1 + (x[:, 0] == '青年') + (x[:, 0] == '老年')
x[:, 1] = (x[:, 1] == '无工作') + 1
x[:, 2] = (x[:, 2] == '无房子') + 1
x[:, 3] = (x[:, 3] == '一般') + 1 + (x[:, 3] == '好') + (x[:, 3] == '一般')
y = (y == 'y') + 0
t = generate_tree(x, y)
t.tree_print()
可以得到结果如下:
0 {'j': 2, 's': '1', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '1', '2', '2'], ['2', '1', '1', '3'], ['2', '2', '2', '3'], ['1', '2', '2', '3'], ['1', '2', '2', '2'], ['1', '1', '1', '2'], ['1', '2', '1', '1'], ['1', '2', '1', '1'], ['3', '2', '1', '1'], ['3', '2', '1', '2'], ['3', '1', '2', '2'], ['3', '1', '2', '1'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]), 'c': 1}
1 {'j': 0, 's': '1', 'x': array([['2', '1', '1', '3'], ['1', '1', '1', '2'], ['1', '2', '1', '1'], ['1', '2', '1', '1'], ['3', '2', '1', '1'], ['3', '2', '1', '2']], dtype='<U3'), 'y': array([1, 1, 1, 1, 1, 1]), 'c': 1}
2 {'j': 1, 's': '1', 'x': array([['1', '1', '1', '2'], ['1', '2', '1', '1'], ['1', '2', '1', '1']], dtype='<U3'), 'y': array([1, 1, 1]), 'c': 1}
3 {'x': array([['1', '1', '1', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
3 {'x': array([['1', '2', '1', '1'], ['1', '2', '1', '1']], dtype='<U3'), 'y': array([1, 1]), 'c': 1}
2 {'j': 0, 's': '2', 'x': array([['2', '1', '1', '3'], ['3', '2', '1', '1'], ['3', '2', '1', '2']], dtype='<U3'), 'y': array([1, 1, 1]), 'c': 1}
3 {'x': array([['2', '1', '1', '3']], dtype='<U3'), 'y': array([1]), 'c': 1}
3 {'j': 3, 's': '1', 'x': array([['3', '2', '1', '1'], ['3', '2', '1', '2']], dtype='<U3'), 'y': array([1, 1]), 'c': 1}
4 {'x': array([['3', '2', '1', '1']], dtype='<U3'), 'y': array([1]), 'c': 1}
4 {'x': array([['3', '2', '1', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
1 {'j': 1, 's': '1', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '1', '2', '2'], ['2', '2', '2', '3'], ['1', '2', '2', '3'], ['1', '2', '2', '2'], ['3', '1', '2', '2'], ['3', '1', '2', '1'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 1, 0, 0, 0, 1, 1, 0]), 'c': 0}
2 {'j': 0, 's': '2', 'x': array([['2', '1', '2', '2'], ['3', '1', '2', '2'], ['3', '1', '2', '1']], dtype='<U3'), 'y': array([1, 1, 1]), 'c': 1}
3 {'x': array([['2', '1', '2', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
3 {'j': 3, 's': '1', 'x': array([['3', '1', '2', '2'], ['3', '1', '2', '1']], dtype='<U3'), 'y': array([1, 1]), 'c': 1}
4 {'x': array([['3', '1', '2', '1']], dtype='<U3'), 'y': array([1]), 'c': 1}
4 {'x': array([['3', '1', '2', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
2 {'j': 0, 's': '1', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '2', '2', '3'], ['1', '2', '2', '3'], ['1', '2', '2', '2'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 0, 0, 0, 0]), 'c': 0}
3 {'j': 3, 's': '2', 'x': array([['1', '2', '2', '3'], ['1', '2', '2', '2']], dtype='<U3'), 'y': array([0, 0]), 'c': 0}
4 {'x': array([['1', '2', '2', '2']], dtype='<U3'), 'y': array([0]), 'c': 0}
4 {'x': array([['1', '2', '2', '3']], dtype='<U3'), 'y': array([0]), 'c': 0}
3 {'j': 0, 's': '2', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '2', '2', '3'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 0, 0]), 'c': 0}
4 {'j': 3, 's': '2', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 0]), 'c': 0}
5 {'x': array([['2', '2', '2', '2']], dtype='<U3'), 'y': array([0]), 'c': 0}
5 {'x': array([['2', '2', '2', '3'], ['2', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0]), 'c': 0}
4 {'x': array([['3', '2', '2', '3']], dtype='<U3'), 'y': array([0]), 'c': 0}其中,为了更方便地实现二叉树的打印,可在二叉树类(见学习笔记|k近邻法的实现)中加上树的打印:
def tree_print(self, level=0):
print(level, self.key)
level += 1
if isinstance(self.left_child, binary_tree):
self.left_child.tree_print(level)
if isinstance(self.right_child, binary_tree):
self.right_child.tree_print(level)
2. CART树(分类树)生成算法的实现
首先,计算变量的Gini指数:
def cal_gini(y):
e = dict(Counter(y))
return 1 - np.sum([(v / len(y)) ** 2 for k, v in e.items()])
其次,计算划分后的Gini指数:
def cal_classification_loss(x, y, j, a):
r1 = np.where(x[:, j] == a)
r2 = np.where(x[:, j] != a)
return len(y[r1]) / len(y) * cal_gini(y[r1]) + len(y[r2]) / len(y) * cal_gini(y[r2]), r1, r2
最后,生成CART树只需要将生成最小二乘回归树中的损失损失函数cal_regression_loss改成以上计算划分后的Gini指数的函数cal_classification_loss即可,可以通过参数的方式,简单改造原决策树生成算法来实现:
def generate_tree(x, y, loss_mode='regression'):
loss_function = cal_regression_loss if loss_mode == 'regression' else cal_classification_loss
if not len(y):
return None
cut = []
for j in range(x.shape[1]):
unique_x = np.unique(x[:, j])
if len(unique_x) > 1:
for s in unique_x:
l, r1, r2 = loss_function(x, y, j, s)
if not len(cut):
cut = [j, s, l, r1, r2]
elif cut[2] > l:
cut = [j, s, l, r1, r2]
if len(cut):
t = binary_tree({'j': cut[0], 's': cut[1], 'x': x, 'y': y, 'c': get_max_class(y)})
t.left_child = generate_tree(x[cut[3]], y[cut[3]], loss_mode=loss_mode)
t.right_child = generate_tree(x[cut[4]], y[cut[4]], loss_mode=loss_mode)
return t
return binary_tree({'x': x, 'y': y, 'c': get_max_class(y)})
同样可以用上述例子进行简单的尝试:
if __name__ == "__main__":
y = np.array(['n', 'n', 'y', 'y', 'n', 'n', 'n', 'y', 'y', 'y', 'y', 'y', 'y', 'y', 'n'])
x = np.array([['青年', '无工作', '无房子', '一般'], ['青年', '无工作', '无房子', '好'], ['青年', '有工作', '无房子', '好'], ['青年', '有工作', '有房子', '一般'],
['青年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '一般'], ['中年', '无工作', '无房子', '好'], ['中年', '有工作', '有房子', '好'],
['中年', '无工作', '有房子', '非常好'], ['中年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '非常好'], ['老年', '无工作', '有房子', '好'],
['老年', '有工作', '无房子', '好'], ['老年', '有工作', '无房子', '非常好'], ['老年', '无工作', '无房子', '一般']])
x[:, 0] = (x[:, 0] == '老年') + 1 + (x[:, 0] == '青年') + (x[:, 0] == '老年')
x[:, 1] = (x[:, 1] == '无工作') + 1
x[:, 2] = (x[:, 2] == '无房子') + 1
x[:, 3] = (x[:, 3] == '一般') + 1 + (x[:, 3] == '好') + (x[:, 3] == '一般')
y = (y == 'y') + 0
t = generate_tree(x, y, loss_mode='classification')
t.tree_print()
得到结果如下:
0 {'j': 2, 's': '1', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '1', '2', '2'], ['2', '1', '1', '3'], ['2', '2', '2', '3'], ['1', '2', '2', '3'], ['1', '2', '2', '2'], ['1', '1', '1', '2'], ['1', '2', '1', '1'], ['1', '2', '1', '1'], ['3', '2', '1', '1'], ['3', '2', '1', '2'], ['3', '1', '2', '2'], ['3', '1', '2', '1'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]), 'c': 1}
1 {'j': 0, 's': '1', 'x': array([['2', '1', '1', '3'], ['1', '1', '1', '2'], ['1', '2', '1', '1'], ['1', '2', '1', '1'], ['3', '2', '1', '1'], ['3', '2', '1', '2']], dtype='<U3'), 'y': array([1, 1, 1, 1, 1, 1]), 'c': 1}
2 {'j': 1, 's': '1', 'x': array([['1', '1', '1', '2'], ['1', '2', '1', '1'], ['1', '2', '1', '1']], dtype='<U3'), 'y': array([1, 1, 1]), 'c': 1}
3 {'x': array([['1', '1', '1', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
3 {'x': array([['1', '2', '1', '1'], ['1', '2', '1', '1']], dtype='<U3'), 'y': array([1, 1]), 'c': 1}
2 {'j': 0, 's': '2', 'x': array([['2', '1', '1', '3'], ['3', '2', '1', '1'], ['3', '2', '1', '2']], dtype='<U3'), 'y': array([1, 1, 1]), 'c': 1}
3 {'x': array([['2', '1', '1', '3']], dtype='<U3'), 'y': array([1]), 'c': 1}
3 {'j': 3, 's': '1', 'x': array([['3', '2', '1', '1'], ['3', '2', '1', '2']], dtype='<U3'), 'y': array([1, 1]), 'c': 1}
4 {'x': array([['3', '2', '1', '1']], dtype='<U3'), 'y': array([1]), 'c': 1}
4 {'x': array([['3', '2', '1', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
1 {'j': 1, 's': '1', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '1', '2', '2'], ['2', '2', '2', '3'], ['1', '2', '2', '3'], ['1', '2', '2', '2'], ['3', '1', '2', '2'], ['3', '1', '2', '1'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 1, 0, 0, 0, 1, 1, 0]), 'c': 0}
2 {'j': 0, 's': '2', 'x': array([['2', '1', '2', '2'], ['3', '1', '2', '2'], ['3', '1', '2', '1']], dtype='<U3'), 'y': array([1, 1, 1]), 'c': 1}
3 {'x': array([['2', '1', '2', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
3 {'j': 3, 's': '1', 'x': array([['3', '1', '2', '2'], ['3', '1', '2', '1']], dtype='<U3'), 'y': array([1, 1]), 'c': 1}
4 {'x': array([['3', '1', '2', '1']], dtype='<U3'), 'y': array([1]), 'c': 1}
4 {'x': array([['3', '1', '2', '2']], dtype='<U3'), 'y': array([1]), 'c': 1}
2 {'j': 0, 's': '1', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '2', '2', '3'], ['1', '2', '2', '3'], ['1', '2', '2', '2'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 0, 0, 0, 0]), 'c': 0}
3 {'j': 3, 's': '2', 'x': array([['1', '2', '2', '3'], ['1', '2', '2', '2']], dtype='<U3'), 'y': array([0, 0]), 'c': 0}
4 {'x': array([['1', '2', '2', '2']], dtype='<U3'), 'y': array([0]), 'c': 0}
4 {'x': array([['1', '2', '2', '3']], dtype='<U3'), 'y': array([0]), 'c': 0}
3 {'j': 0, 's': '2', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '2', '2', '3'], ['3', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 0, 0]), 'c': 0}
4 {'j': 3, 's': '2', 'x': array([['2', '2', '2', '3'], ['2', '2', '2', '2'], ['2', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0, 0]), 'c': 0}
5 {'x': array([['2', '2', '2', '2']], dtype='<U3'), 'y': array([0]), 'c': 0}
5 {'x': array([['2', '2', '2', '3'], ['2', '2', '2', '3']], dtype='<U3'), 'y': array([0, 0]), 'c': 0}
4 {'x': array([['3', '2', '2', '3']], dtype='<U3'), 'y': array([0]), 'c': 0}
参考文献
【1】统计学习方法(第2版),李航著,清华大学出版社
- 点赞
- 收藏
- 关注作者
评论(0)