监督学习算法中K近邻算法(K-Nearest Neighbors)

举报
皮牙子抓饭 发表于 2023/08/27 17:49:18 2023/08/27
【摘要】 K近邻算法(K-Nearest Neighbors, KNN)是一种常用的监督学习算法,用于分类和回归问题。该算法基于一个简单的假设:相似的样本具有相似的标签。KNN算法通过计算样本之间的距离来确定最近的K个邻居,并根据这些邻居的标签进行预测。 KNN算法的工作流程如下:计算距离:对于一个给定的测试样本,计算它与训练集中所有样本之间的距离。常用的距离度量方法有欧氏距离、曼哈顿距离等。选择K个...

K近邻算法(K-Nearest Neighbors, KNN)是一种常用的监督学习算法,用于分类和回归问题。该算法基于一个简单的假设:相似的样本具有相似的标签。KNN算法通过计算样本之间的距离来确定最近的K个邻居,并根据这些邻居的标签进行预测。 KNN算法的工作流程如下:

  1. 计算距离:对于一个给定的测试样本,计算它与训练集中所有样本之间的距离。常用的距离度量方法有欧氏距离、曼哈顿距离等。
  2. 选择K个邻居:根据计算得到的距离,选择与测试样本最近的K个训练样本作为邻居。
  3. 进行投票或计算平均值:对于分类问题,K个邻居中出现最频繁的类别作为测试样本的预测标签;对于回归问题,K个邻居的标签的平均值作为测试样本的预测值。 KNN算法的优点包括:
  • 简单易理解,无需假设数据分布。
  • 对异常值和噪声具有较好的鲁棒性。
  • 可以适用于多分类问题。 KNN算法的缺点包括:
  • 计算复杂度高:在进行预测时,需要计算测试样本与所有训练样本的距离,计算量较大。
  • 对数据分布敏感:如果训练样本不平衡或噪声较多,可能会导致预测结果不准确。
  • 需要选择合适的K值:K值的选择会影响算法的性能,需要通过交叉验证等方法进行调优。 在实际应用中,KNN算法常用于数据集较小、特征维度较低的场景。它可以用于图像识别、推荐系统、文本分类等任务。此外,KNN算法还可以通过一些改进方法,如KD树、球树等来提高算法的效率。

下面是一个使用Python实现KNN算法的示例代码:

pythonCopy codeimport numpy as np
from collections import Counter
class KNN:
    def __init__(self, k):
        self.k = k
    
    def euclidean_distance(self, x1, x2):
        return np.sqrt(np.sum((x1 - x2) ** 2))
    
    def fit(self, X, y):
        self.X_train = X
        self.y_train = y
    
    def predict(self, X):
        y_pred = [self._predict(x) for x in X]
        return np.array(y_pred)
    
    def _predict(self, x):
        distances = [self.euclidean_distance(x, x_train) for x_train in self.X_train]
        k_indices = np.argsort(distances)[:self.k]
        k_nearest_labels = [self.y_train[i] for i in k_indices]
        most_common = Counter(k_nearest_labels).most_common(1)
        return most_common[0][0]

使用示例代码,可以按照以下步骤进行:

pythonCopy code# 创建一个KNN对象,设置k值为3
knn = KNN(k=3)
# 准备训练数据
X_train = np.array([[1, 2], [1, 4], [2, 2], [3, 4]])
y_train = np.array([0, 0, 1, 1])
# 拟合模型
knn.fit(X_train, y_train)
# 准备测试数据
X_test = np.array([[2, 3], [3, 3]])
# 进行预测
y_pred = knn.predict(X_test)
# 输出预测结果
print(y_pred)

输出结果将会是预测的标签数组 ​​[0, 1]​​。这表示模型预测第一个测试样本属于类别0,第二个测试样本属于类别1。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。