生成em算法混用到高斯混合模型的代码
时间: 2023-10-23 07:14:15 浏览: 128
好的,下面是使用EM算法进行高斯混合模型拟合的Python代码示例:
```
import numpy as np
from scipy.stats import multivariate_normal
class GMM:
def __init__(self, n_components=1, max_iter=100, tol=1e-4):
self.n_components = n_components # 高斯分布个数
self.max_iter = max_iter # 最大迭代次数
self.tol = tol # 迭代停止阈值
self.alpha = None # 每个高斯分布的权重
self.mu = None # 每个高斯分布的均值
self.sigma = None # 每个高斯分布的协方差矩阵
def fit(self, X):
n_samples, n_features = X.shape
self.alpha = np.ones(self.n_components) / self.n_components # 初始化权重
self.mu = np.random.randn(self.n_components, n_features) # 随机初始化均值
self.sigma = np.array([np.eye(n_features)] * self.n_components) # 初始化协方差矩阵
for _ in range(self.max_iter):
# E-step
gamma = np.zeros((n_samples, self.n_components))
for k in range(self.n_components):
gamma[:, k] = self.alpha[k] * multivariate_normal.pdf(X, self.mu[k], self.sigma[k])
gamma /= gamma.sum(axis=1, keepdims=True)
# M-step
Nk = gamma.sum(axis=0)
self.alpha = Nk / n_samples
self.mu = gamma.T @ X / Nk[:, np.newaxis]
for k in range(self.n_components):
diff = X - self.mu[k]
self.sigma[k] = (diff.T @ (diff * gamma[:, k, np.newaxis])) / Nk[k]
# 检查收敛
if np.abs(gamma - gamma.mean()).mean() < self.tol:
break
def predict(self, X):
gamma = np.zeros((X.shape[0], self.n_components))
for k in range(self.n_components):
gamma[:, k] = self.alpha[k] * multivariate_normal.pdf(X, self.mu[k], self.sigma[k])
return gamma.argmax(axis=1)
```
这里实现了一个高斯混合模型类`GMM`,其中`n_components`表示高斯分布个数,`max_iter`表示最大迭代次数,`tol`表示迭代停止阈值。在`fit`函数中,首先对权重、均值和协方差矩阵进行初始化,然后进行EM算法的迭代,分别实现了E-step和M-step,最后检查迭代是否收敛。在`predict`函数中,返回每个样本属于哪个高斯分布的概率最大。
需要注意的是,这里使用了多元高斯分布函数`multivariate_normal`,它需要先进行导入:
```
from scipy.stats import multivariate_normal
```
使用示例:
```
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
# 生成数据集
X, y = make_blobs(n_samples=1000, centers=3, random_state=42)
# 拟合高斯混合模型
gmm = GMM(n_components=3)
gmm.fit(X)
# 可视化聚类结果
colors = ['r', 'g', 'b']
for k in range(gmm.n_components):
plt.scatter(X[gmm.predict(X) == k, 0], X[gmm.predict(X) == k, 1], c=colors[k], alpha=0.5)
plt.show()
```
这里使用了`make_blobs`函数生成了一个二维数据集,然后对数据集进行高斯混合模型拟合,并可视化聚类结果。
阅读全文