EM算法与高斯混合聚类:理解与实践

举报
小馒头学Python 发表于 2024/11/22 22:11:20 2024/11/22
【摘要】 🍋引言在数据科学和机器学习中,聚类是重要的无监督学习任务。高斯混合模型(GMM)是一种常用的概率模型,用于描述数据的分布。在应用高斯混合模型时,EM(Expectation-Maximization)算法被广泛用于参数估计。本文将深入探讨EM算法的基本原理,并结合高斯混合模型,展示如何实现基于EM算法的聚类。 🍋什么是EM算法? 🍋EM算法的基本原理EM算法是一种迭代算法,广泛用于估...

🍋引言

在数据科学和机器学习中,聚类是重要的无监督学习任务。高斯混合模型(GMM)是一种常用的概率模型,用于描述数据的分布。在应用高斯混合模型时,EM(Expectation-Maximization)算法被广泛用于参数估计。本文将深入探讨EM算法的基本原理,并结合高斯混合模型,展示如何实现基于EM算法的聚类。

🍋什么是EM算法?

🍋EM算法的基本原理

EM算法是一种迭代算法,广泛用于估计含有隐变量(latent variables)的概率模型的参数。它主要由两步组成:

  • E步(期望步,Expectation step):根据当前模型的参数,计算隐变量的期望值。
  • M步(最大化步,Maximization step):根据E步的结果,最大化似然函数,更新模型的参数。

EM算法通过不断重复E步和M步,逐步逼近最大似然估计。

🍋EM算法的应用领域

EM算法不仅在高斯混合模型中有广泛应用,还在许多其他领域也有应用,例如:

  • 高斯混合模型(GMM)
  • 隐马尔可夫模型(HMM)
  • 聚类分析
  • 图像分割

🍋高斯混合模型(GMM)简介

🍋GMM模型的定义

高斯混合模型是一种假设数据点是由多个高斯分布成分组成的概率模型。每个高斯成分有自己的均值、方差和权重。GMM是通过EM算法来估计这些参数的。

GMM的概率密度函数可以表示为:

image.png

🍋GMM在聚类中的作用

在聚类问题中,GMM通过拟合多个高斯分布来表示不同的聚类中心,数据点的归属通过计算其属于各个高斯成分的概率来确定。与K-means算法相比,GMM可以捕捉数据的多模态特性,不仅仅是基于距离的硬分类,还能通过概率分配进行软分类。

🍋EM算法与高斯混合聚类的结合

🍋如何用EM算法训练GMM

使用EM算法训练高斯混合模型时,主要目标是最大化数据点在模型下的对数似然函数。每次迭代中,E步通过计算数据点属于每个高斯成分的概率,M步则更新模型参数,使得对数似然函数最大化。

具体步骤如下:

  • 初始化:初始化高斯成分的均值、协方差矩阵和权重。
  • E步:计算每个数据点属于每个高斯成分的概率,称为责任度(responsibility)。
  • M步:根据E步的结果,更新均值、协方差矩阵和权重。
  • 重复:重复E步和M步,直到对数似然函数收敛。

🍋GMM在聚类中的实际应用

GMM广泛应用于图像处理、文本分析、市场细分等领域。在聚类任务中,GMM可以帮助发现数据中的潜在模式,并且相比于传统的K-means算法,它能够更好地处理复杂的分布。

🍋Python实现EM算法与高斯混合聚类

🍋导入必要的库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs

🍋生成模拟数据

我们使用make_blobs生成带有多个聚类的模拟数据:

# 生成模拟数据
X, y = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
plt.scatter(X[:, 0], X[:, 1], s=30)
plt.title("Generated Data")
plt.show()

🍋应用GMM进行聚类

我们使用GaussianMixture类来拟合GMM模型:

# 使用GMM进行聚类
gmm = GaussianMixture(n_components=4)
gmm.fit(X)
labels = gmm.predict(X)

# 可视化聚类结果
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')
plt.title("GMM Clustering Results")
plt.show()

🍋可视化结果

# 可视化每个聚类的高斯分布
ax = plt.gca()
ax.set_title("GMM Clustering with Gaussian Components")

# 绘制数据点
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')

# 绘制每个高斯分布的轮廓
for i in range(4):
    covariances = gmm.covariances_[i]
    mean = gmm.means_[i]
    v, w = np.linalg.eigh(covariances)
    v = 2.0 * np.sqrt(2.0) * np.sqrt(v)  # Elongate the axes by a factor of 2
    u = w[0] / np.linalg.norm(w[0])  # Normalize the eigenvector
    angle = np.arctan(u[1] / u[0])  # Rotation angle

    # Plot ellipse for each Gaussian component
    ellipse = plt.matplotlib.patches.Ellipse(mean, v[0], v[1], 180.0 * angle / np.pi, color='red', alpha=0.4)
    ax.add_patch(ellipse)

plt.show()

🍋完整源码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs

# 生成模拟数据
X, y = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)

# 可视化生成的数据
plt.scatter(X[:, 0], X[:, 1], s=30)
plt.title("Generated Data")
plt.show()

# 使用GMM进行聚类
gmm = GaussianMixture(n_components=4)
gmm.fit(X)
labels = gmm.predict(X)

# 可视化聚类结果
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')
plt.title("GMM Clustering Results")
plt.show()

# 可视化每个聚类的高斯分布
ax = plt.gca()
ax.set_title("GMM Clustering with Gaussian Components")

# 绘制数据点
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')

# 绘制每个高斯分布的轮廓
for i in range(4):
    covariances = gmm.covariances_[i]
    mean = gmm.means_[i]
    v, w = np.linalg.eigh(covariances)
    v = 2.0 * np.sqrt(2.0) * np.sqrt(v)  # Elongate the axes by a factor of 2
    u = w[0] / np.linalg.norm(w[0])  # Normalize the eigenvector
    angle = np.arctan(u[1] / u[0])  # Rotation angle

    # 修改部分:修正了 Ellipse 参数传递
    ellipse = plt.matplotlib.patches.Ellipse(mean, v[0], v[1], angle=angle * 180.0 / np.pi, color='red', alpha=0.4)
    ax.add_patch(ellipse)

plt.show()

可视化:

  • 第一张图(生成的数据图)展示了数据的原始分布,帮助我们理解数据的聚类结构。
  • 第二张图(GMM聚类结果图)展示了使用 GMM 进行聚类的结果,表示数据点的分类情况。
  • 第三张图(每个聚类的高斯分布轮廓图)展示了每个聚类对应的高斯分布的轮廓,帮助我们更深入地理解 GMM 对数据的建模方式,展示了每个聚类的协方差结构。
    image.png

image.png

image.png

🍋实践案例(Wine数据集)

Wine 数据集概况
样本数:178
类别数:3
特征数:13

import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import load_wine
from sklearn.decomposition import PCA

# 加载Wine数据集
wine = load_wine()
X = wine.data  # 特征数据
y = wine.target  # 标签(真实类别)

# 使用PCA进行降维,便于在二维平面中可视化
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

# 使用GMM进行聚类
gmm = GaussianMixture(n_components=3, random_state=0)
gmm.fit(X_pca)
labels = gmm.predict(X_pca)

# 1. 原始数据的分布图
plt.figure(figsize=(6, 6))
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='viridis', s=50, edgecolor='k', alpha=0.7)
plt.title("Wine Dataset - Original Data Distribution")
plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.colorbar(label='True Labels')
plt.show()

# 2. GMM 聚类结果图
plt.figure(figsize=(6, 6))
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=labels, cmap='viridis', s=50, edgecolor='k', alpha=0.7)
plt.title("GMM Clustering Results")
plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.colorbar(label='Cluster Labels')
plt.show()

# 3. GMM 每个聚类的高斯分布轮廓图
plt.figure(figsize=(6, 6))
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=labels, cmap='viridis', s=50, edgecolor='k', alpha=0.7)
plt.title("GMM Clustering with Gaussian Components")

# 绘制每个高斯成分的轮廓
ax = plt.gca()
for i in range(3):  # 有3个聚类
    mean = gmm.means_[i]
    cov = gmm.covariances_[i]
    v, w = np.linalg.eigh(cov)
    v = 2.0 * np.sqrt(2.0) * np.sqrt(v)  # Elongate the axes by a factor of 2
    u = w[0] / np.linalg.norm(w[0])  # Normalize the eigenvector
    angle = np.arctan(u[1] / u[0])  # Rotation angle

    # 绘制椭圆
    ellipse = plt.matplotlib.patches.Ellipse(mean, v[0], v[1], angle=angle * 180.0 / np.pi, color='red', alpha=0.4)
    ax.add_patch(ellipse)

plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.colorbar(label='Cluster Labels')
plt.show()
  • 原始数据分布图:展示了真实类别(酒的种类)的数据分布情况。
  • GMM 聚类结果图:展示了 GMM 聚类后数据的分配情况,聚类标签可能与真实标签不完全匹配,因为 GMM 是无监督学习方法。
  • 每个聚类的高斯分布轮廓图:展示了每个聚类的高斯分布模型的轮廓(通过椭圆表示),帮助理解 GMM 如何对数据建模。
    image.png

image.png

image.png

🍋总结

通过本文的介绍,我们了解了EM算法的基本原理,并结合高斯混合模型(GMM)展示了如何使用EM算法进行聚类。与传统的K-means聚类相比,GMM能够提供更精确的结果,尤其是在数据分布不规则或复杂时。随着算法的不断发展和优化,GMM在实际应用中将展现出更大的潜力。

🍋参考文献

【1】Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.

  • 链接:https://www.papiro-bookstore.com/wp-content/uploads/2021/12/Pattern-Recognition-and-Machine-Learning.pdf
  • 书籍概述:这本书是由 Christopher M. Bishop 编写的,涵盖了模式识别和机器学习的许多基础概念。书中的内容包括了监督学习、无监督学习、图模型、神经网络、贝叶斯方法等。特别是在机器学习和模式识别的应用中,Bishop 的这本书被广泛引用,尤其适用于研究统计学和概率推理方法在模式识别中的应用。
  • 主要贡献:该书深入介绍了机器学习的理论基础和实际算法,尤其强调了统计模型和概率方法对模式识别问题的解决方案。
    【2】Murphy, K. P. (2012). Machine Learning: A Probabilistic Perspective. MIT Press.
  • 链接:https://books.google.com.sg/books?hl=zh-CN&lr=&id=RC43AgAAQBAJ&oi=fnd&pg=PR7&dq=Murphy,+K.+P.+(2012).+Machine+Learning:+A+Probabilistic+Perspective.+MIT+Press.&ots=univgzLvY4&sig=0QeEPELLTyhuc_Y1cHs6em6A29o&redir_esc=y#v=onepage&q=Murphy%2C%20K.%20P.%20(2012).%20Machine%20Learning%3A%20A%20Probabilistic%20Perspective.%20MIT%20Press.&f=false
  • 书籍概述:这本书由 Kevin P. Murphy 编写,是一本关于机器学习的经典教材,重点介绍了机器学习中的概率模型。书中内容不仅详细阐述了机器学习的基础算法,还深入讨论了各种基于概率的学习方法,如贝叶斯网络、隐马尔可夫模型、图模型等。它对概率推理和贝叶斯学习有深入的阐释。
  • 主要贡献:Murphy 的这本书通过严格的数学推导,帮助读者理解机器学习算法背后的概率基础,适合有一定统计学和数学基础的读者。
    【3】Dempster, A. P., Laird, N. M., & Rubin, D. B. (1977). Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society. Series B (Methodological), 39(1), 1-38.
  • 链接:https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.2517-6161.1977.tb01600.x
  • 文章概述:这篇文章是 Dempster, Laird 和 Rubin 在 1977 年发表的,介绍了 EM算法(期望最大化算法,Expectation-Maximization Algorithm)。EM算法是一种用于估计包含隐变量的概率模型参数的统计方法,广泛应用于缺失数据的最大似然估计。EM算法的核心思想是通过交替执行期望步(E步)和最大化步(M步),逐步优化参数。
  • 主要贡献:EM算法为解决许多实际问题中的参数估计问题提供了强有力的工具,尤其是在存在缺失数据或隐变量的情况下。这篇论文奠定了 EM 算法在统计学和机器学习领域中的基础地位,是该领域的开创性工作。

image.png

挑战与创造都是很痛苦的,但是很充实。

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。