学习笔记|k近邻法的实现

举报
darkpard 发表于 2021/10/23 09:31:46 2021/10/23
【摘要】 学习笔记|k近邻分类算法 指出k近邻分类算法通过kd树的构造和搜索来实现。1. 构建二叉树类为了实现kd树的构造和搜索算法,我们先构建一个二叉树类。首先,申明类,初始化根结点和左、右子结点。class binary_tree(object): def __init__(self, root_obj): self.key = root_obj self.lef...

学习笔记|k近邻分类算法 指出k近邻分类算法通过kd树的构造和搜索来实现。

1. 构建二叉树类

为了实现kd树的构造和搜索算法,我们先构建一个二叉树类。首先,申明类,初始化根结点和左、右子结点。

class binary_tree(object):
    def __init__(self, root_obj):
        self.key = root_obj
        self.left_child = None
        self.right_child = None

其次,构造插入左、右子树方法。

    def insert_left_child(self, new_obj):
        new_tree = binary_tree(new_obj)
        if self.left_child == None:
            self.left_child = new_tree
        else:
            new_tree.left_child = self.left_child
            self.left_child = new_tree

    def insert_right_child(self, new_obj):
        new_tree = binary_tree(new_obj)
        if self.right_child == None:
            self.right_child = new_tree
        else:
            new_tree.right_child = self.right_child
            self.right_child = new_tree

再次,构造读、写根节点方法。

    def get_root_value(self):
        return self.key

    def set_root_value(self, root_obj):
        self.key = root_obj

最后,构造读取左、右子树方法。

    def get_left_child(self):
        return self.left_child

    def get_right_child(self):
        return self.right_child

2. 构造kd树

在二叉树的基础上构造kd树,事实上kd树只需要实现二叉树的部分功能。

class kd_tree(binary_tree):
    def set_child(self, new_obj, lr='l'):
        if lr == 'l':
            self.left_child = new_obj
        else:
            self.right_child = new_obj

然后生成kd树。

def generate_kd_tree(x, d=0):
    if len(x) > 1:
        x = x[np.argsort(-x[:,d])]
        mi = np.argmin(np.abs(x[:, d] - np.median(x[:, d])))
        kd_tree1 = kd_tree(x[mi])
        d += 1
        if len(x[:mi]):
            kd_tree1.set_child(generate_kd_tree(x[:mi], d), lr='r')
        if len(x[mi+1:]):
            kd_tree1.set_child(generate_kd_tree(x[mi+1:], d), lr='l')
        return kd_tree1
    return kd_tree(x[0])

根据书本(参考文献2)上的案例对生成的kd树进行简单验证。

if __name__ == "__main__":
    kd_tree = generate_kd_tree(np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]))
    print(kd_tree.key, kd_tree.left_child.key, kd_tree.right_child.key, kd_tree.left_child.left_child.key, kd_tree.left_child.right_child.key, kd_tree.right_child.left_child.key)

得到:
[7 2] [5 4] [9 6] [2 3] [4 7] [8 1]

3. kd树搜索

kd树的搜索可以通过递归的方法来实现。当然,有点偷懒,这里的代码比较冗长。

def search_kd_tree(kd_tree1, s, d=0):
    if kd_tree1.left_child == None:
        if kd_tree1.right_child == None:
            return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
        if np.linalg.norm(kd_tree1.right_child.key - s) < np.linalg.norm(kd_tree1.key - s):
            return {'p': kd_tree1.right_child.key, 'r': np.linalg.norm(kd_tree1.right_child.key - s)}
        return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
    elif kd_tree1.right_child == None:
        if np.linalg.norm(kd_tree1.left_child.key - s) < np.linalg.norm(kd_tree1.key - s):
            return {'p': kd_tree1.left_child.key, 'r': np.linalg.norm(kd_tree1.left_child.key - s)}
        return {'p': kd_tree1.key, 'r': np.linalg.norm(kd_tree1.key - s)}
    if s[d] < kd_tree1.key[d]:
        d = (d + 1) % len(s)
        c = search_kd_tree(kd_tree1.left_child, s, d=d)
        lr = 'l'
    else:
        d = (d + 1) % len(s)
        c = search_kd_tree(kd_tree1.right_child, s, d=d)
        lr = 'r'
    if np.linalg.norm(kd_tree1.key - s) < c['r']:
        c['p'] = kd_tree1.key
        c['r'] = np.linalg.norm(c['p'] - 1)
    s1 = s.copy()
    s1[d] = kd_tree1.key[d]
    if np.linalg.norm(s1 - s) < c['r']:
        if lr == 'l':
            c1 = search_kd_tree(kd_tree1.right_child, s, d=d)
        else:
            c1 = search_kd_tree(kd_tree1.left_child, s, d=d)
        if c1['r'] < c['r']:
            c = c1
    return c

4. kd树搜索举例

if __name__ == "__main__":
    kd_tree1 = generate_kd_tree(np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]))
    print(kd_tree1.key, kd_tree1.left_child.key, kd_tree1.right_child.key, kd_tree1.left_child.left_child.key, kd_tree1.left_child.right_child.key, kd_tree1.right_child.left_child.key)
    print(search_kd_tree(kd_tree1, np.array([2, 4.5])))
    print(search_kd_tree(kd_tree1, np.array([9, 7])))

结合前面的例子,做个简单的验证,可以看到当s=(2,4.5)时,最近邻是点(2,3),距离是1.5;当s=(9,7)时,最近邻是点(9,6),距离是1.0。其中,s=(9,7)是具有一定的特殊性的,这里不再赘述。
[7 2] [5 4] [9 6] [2 3] [4 7] [8 1]
{'p': array([2, 3]), 'r': 1.5}
{'p': array([9, 6]), 'r': 1.0}

参考文献

[1]https://blog.csdn.net/m0_37324740/article/details/79435814
[2]统计学习方法(第2版),李航著,清华大学出版社 [3]https://zhuanlan.zhihu.com/p/104758420

相关链接:

  1. 学习笔记|k近邻分类算法
  2. 学习笔记|感知机的实现
  3. 学习笔记|朴素贝叶斯法
  4. 学习笔记|决策树模型及其学习
  5. 学习笔记|感知机(二)
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。