Python从0到100(五十六):机器学习-K均值聚类鸢尾花数据集聚类
K均值聚类是⼀种常⽤的⽆监督学习算法,⽤于将数据集分成 K 个不同的类别(簇),使得同⼀类别内的样本点彼此距离最近,不同类别之间的样本点距离最远。K均值聚类算法
通过迭代优化来实现聚类,是⼀种简单⽽有效的聚类算法。
1.基本原理
K均值聚类的基本原理如下:
1、随机初始化:⾸先,随机选择K个数据点作为初始簇中⼼点。
2、分配数据点:对于每个数据点,计算其与各个簇中⼼点的距离,并将其分配到距离最近的簇中⼼点所属的簇。
3、更新簇中⼼点:对于每个簇,计算该簇内所有数据点的平均值,将其作为新的簇中⼼点。
4、重复迭代:重复步骤2和步骤3,直到簇中⼼点不再发⽣明显变化,或者达到预定的迭代次数。
2.公式模型
K均值聚类的核⼼公式包括计算样本点到聚类中⼼的距离以及更新聚类中⼼的公式。具体⽽⾔,距离的计算通常采⽤欧式距离:
其中,xi是样本点,cj是聚类中⼼,n是特征的数量。
推导K均值聚类的过程涉及到对样本点进⾏聚类并更新聚类中⼼,通过最⼩化每个类别内样本点到聚类中⼼的距离来优化聚类结果。
K均值聚类的结果是将数据点划分为K个簇,并且每个簇由⼀个中⼼点表示。这个算法通常需要多次运⾏,并在不同的初始簇中⼼点选择下进⾏迭代,以找到全局最优解。K均值聚类是⼀种简单⽽有效的聚类算法,但它对初始簇中⼼点的选择敏感,并且需要指定K的值。
3.优缺点
优点:
- 简单⾼效:K均值聚类算法简单易懂,计算效率⾼。
- 可扩展性强:K均值聚类算法适⽤于⼤规模数据集,并且可以⽅便地进⾏分布式计算。
- 容易解释结果:K均值聚类产⽣的聚类结果直观,易于解释和理解。
缺点:
- 对初始聚类中⼼敏感:K均值聚类对初始聚类中⼼的选择较为敏感,不同的初始值可能导致不同的聚类结果。
- 需要指定聚类数K:K均值聚类算法需要事先指定聚类数K,对K的选择需要⼀定的领域知识或者通过试验确定。
- 对异常值敏感:K均值聚类对异常值较为敏感,可能会影响聚类结果的准确性。
4.适用场景
K均值聚类适⽤于以下场景:
- 数据集中类别的数量已知或者可以通过领域知识估计。
- 数据集的特征相对均匀分布,各个类别内样本点的⽅差相差不⼤。
- 需要快速对⼤规模数据集进⾏聚类。
K均值聚类是⼀种简单⽽有效的聚类算法,尤其适⽤于类别数量已知且数据集相对均匀分布的情况。然⽽,在处理异常值和需要确定聚类数量的情况下,K均值聚类的性能可能会受到⼀定影响。
5.鸢尾花数据集聚类
基于开源数据集的K均值聚类实例代码,使⽤鸢尾花数据集(Iris dataset)进⾏聚类,并展示聚类结果的可视化:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
# 特征标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 使⽤PCA进⾏降维
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
# 构建K均值聚类模型
kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.fit(X_scaled)
# 获取聚类中⼼和预测类别
cluster_centers = kmeans.cluster_centers_
y_pred = kmeans.labels_
# 可视化聚类结果
plt.figure(figsize=(10, 8))
# 绘制原始数据的散点图
plt.subplot(2, 1, 1)
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='viridis', s=50, alpha=0.8)
plt.title('Original Data')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
# 绘制聚类结果的散点图
plt.subplot(2, 1, 2)
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y_pred, cmap='viridis', s=50, alpha=0.8)
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], c='red', marker='x', s=200,
label='Cluster Centers')
plt.title('K-Means Clustering')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.tight_layout()
plt.show()
⾸先加载鸢尾花数据集,并对数据进⾏了特征标准化和降维(使⽤PCA进⾏降维)。然后构建了⼀个K均值聚类模型,并在降维后的数据上进⾏了聚类。最后,通过绘制散点图展示了原始数据和聚类结果。
- 点赞
- 收藏
- 关注作者
评论(0)