Python 实现 KNN 算法

举报
TiAmoZhang 发表于 2023/05/31 10:05:35 2023/05/31
【摘要】 本篇我们将讨论一种广泛使用的分类技术,称为k邻近算法,或者说K最近邻(KNN,k-Nearest Neighbor)。所谓K最近邻,是k个最近的邻居的意思,即每个样本都可以用它最接近的k个邻居来代表。

本篇我们将讨论一种广泛使用的分类技术,称为 k 邻近算法,或者说 K 最近邻(KNN,k-Nearest Neighbor)。所谓 K 最近邻,是 k 个最近的邻居的意思,即每个样本都可以用它最接近的 k 个邻居来代表。

01、KNN 算法思想

如果一个样本在特征空间中的 k 个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN 方法在类别决策时,只与极少量的相邻样本有关。


由于 KNN 方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN 方法较其他方法更为适合。

02、KNN 算法的决策过程

下图中有两种类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形,中间那个绿色的圆形是待分类数据:

图片

近邻分类图

如果K=3,那么离绿色点最近的有2个红色的三角形和1个蓝色的正方形,这三个点进行投票,于是绿色的待分类点就属于红色的三角形。而如果K=5,那么离绿色点最近的有2个红色的三角形和3个蓝色的正方形,这五个点进行投票,于是绿色的待分类点就属于蓝色的正方形。 


KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成反比。


下面用代码来实现KNN算法的应用。本次用到的数据是经典的Iris数据集。该数据集有150条鸢尾花数据样本,并且均匀分布在3个不同的亚种:每个数据样本被4个不同的花瓣、花萼的形状特征所描述。

#读取数据
from sklearn.datasets import load_iris
data = load_iris()
#查看数据大小
data.data.shape
(150, 4)
#查看数据说明
print (data.DESCR)
Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988
This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris
The famous Iris database, first used by Sir R.A Fisher
This is perhaps the best known database to be found in the pattern recognition literature.  Fisher's paper is a classic in the field and is referenced frequently to this day.  (See Duda & Hart, for example.)  The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant.  One class is linearly separable from the other 2; the latter are NOT linearly separable from each other.

References
----------
   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

通过上述代码对数据的查验以及数据本身的描述,我们可以了解到Iris数据集共有150条鸢尾花数据样本,并且均匀分布在3个不同的亚种;每一个数据样本被4个不同的花瓣、花萼的形状特征所描述。由于没有指定的测试集,依据管理,我们需要第数据进行随机分割,25%的数据用作测试,75的数据用作训练。


需要强调的是,如果读者朋友自行编写程序用作数据分割,请务必保证是随机采样。尽管很多数据集中的样本的排序相对随机,但是也有例外。本例中,Iris数据就是根据类别一次排列的。如果只采样前25%的数据用作测试,那么所有的测试样本都属于一个类别,同时训练样本也是不均衡的,这样得到的结果存在偏置,并且可信度非常低,Scikit-learn所提供的数据分割模块是默认采用随机采样的功能的,因此大家可不必担心。

#对数据进行分割
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size = 0.25, random_state = 33)

#使用KNN算法进行分类
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
#初始化
ss = StandardScaler()

#数据标准化
X_train = ss.fit_transform(X_train)
X_test = ss.transform(X_test)

#训练模型
knc = KNeighborsClassifier()
knc.fit(X_train, y_train)
#预测
y_pred = knc.predict(X_test)

#模型评估
print ('The accuracy of KNN is:', knc.score(X_test, y_test))
from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred, target_names = data.target_names))

代码输出结果如下,Knn算法对鸢尾花测试数据的分类准确率为89.474%,其他数据如下可见。

图片

KNN算法的特点分析:KNN算法是非常直观的机器学习模型,因此深受广大初学者的喜爱。许多教科书往往一次模型抛砖引玉,便足以看出其不仅特别,而且尚有瑕疵之处。细心的读者会发现,KNN算法与其他算法模型最大的不同在于:该模型没有参数训练过程。也就是说,我们并没有通过任何学习算法来分析训练数据,而只是根据测试样本在训练数据中的的分布直接做出分类决策。因此,KNN算法属于无参数模型中非常简单的一种。然而,正是这样的决策算法,导致了其非常高的计算复杂度和内存消耗。因为该模型每处理一个测试样本,都需要对所有事先加载在内存中的训练样本进行遍历、逐一计算相似度、排序并且选取K个最近邻训练样本的标记,进而做出分类决策。这是平方级的算法复杂度,一旦数据规模稍大,便需要权衡更多计算时间的代价。


最后,对KNN算法做一个简单的小结:

优点

简单,易于理解,易于实现,无需估计参数,无需训练;

适合对稀有事件进行分类;

特别适合于多分类问题(multi-modal,对象具有多个类别标签),kNN比SVM的表现要好。

缺点

当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数,少数类容易分错。

需要存储全部训练样本。

计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。

可理解性差,无法给出像决策树那样的规则。

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。