EM Algorithm Python Code
Time: 2023-08-06 10:22:09
Below is a simple Python example of the EM algorithm, used here to fit a Gaussian mixture model:
```
import numpy as np

def gaussian(x, mu, sigma):
    """Multivariate normal density N(x | mu, sigma), evaluated for each row of x."""
    dim = mu.shape[0]
    diff = x - mu
    # Mahalanobis term diff^T sigma^{-1} diff, one value per row of x
    maha = np.sum(diff @ np.linalg.inv(sigma) * diff, axis=-1)
    norm = np.sqrt((2 * np.pi) ** dim * np.linalg.det(sigma))
    return np.exp(-0.5 * maha) / norm

def em_algorithm(data, num_clusters, max_iterations):
    num_samples, dim = data.shape
    # Initialize means from randomly chosen data points and covariances as identity matrices
    means = data[np.random.choice(num_samples, num_clusters, replace=False)].astype(float)
    covariances = np.array([np.eye(dim) for _ in range(num_clusters)])
    # Initialize mixing coefficients uniformly
    mix_coefficients = np.ones(num_clusters) / num_clusters
    for iteration in range(max_iterations):
        # E-step: compute responsibilities, then normalize each row to sum to 1
        responsibilities = np.zeros((num_samples, num_clusters))
        for j in range(num_clusters):
            responsibilities[:, j] = mix_coefficients[j] * gaussian(data, means[j], covariances[j])
        responsibilities /= responsibilities.sum(axis=1, keepdims=True)
        # M-step: update parameters
        for j in range(num_clusters):
            weight = responsibilities[:, j].sum()
            # Update mean
            means[j] = (responsibilities[:, j] * data.T).sum(axis=1) / weight
            # Update covariance
            diff = data - means[j]
            covariances[j] = (responsibilities[:, j] * diff.T) @ diff / weight
            # Small ridge keeps the covariance invertible (1e-6 is an arbitrary choice)
            covariances[j] += 1e-6 * np.eye(dim)
            # Update mixing coefficient
            mix_coefficients[j] = weight / num_samples
    return means, covariances, mix_coefficients
```
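Here, `data` is an $N\times D$ matrix representing $N$ data points of dimension $D$; `num_clusters` is the number of clusters (mixture components); `max_iterations` is the number of EM iterations to run (the loop has no convergence check, so it always runs this many). The function returns the mean, covariance matrix, and mixing coefficient of each cluster.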
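Because the loop runs a fixed number of iterations, it can help to track the data log-likelihood, which EM should never decrease from one iteration to the next. The following is a minimal sketch of such a check, not part of the original answer: the function name `gmm_log_likelihood` is hypothetical, and it assumes the same array shapes as the values returned above (`means` of shape $K\times D$, `covariances` of shape $K\times D\times D$, `mix_coefficients` of length $K$). It works in the log domain so that very small densities do not underflow to zero.

```python
import numpy as np

def gmm_log_likelihood(data, means, covariances, mix_coefficients):
    """Total log-likelihood of `data` under a Gaussian mixture (hypothetical helper)."""
    num_samples, dim = data.shape
    num_clusters = means.shape[0]
    log_terms = np.zeros((num_samples, num_clusters))
    for j in range(num_clusters):
        diff = data - means[j]
        # Solve covariances[j] @ X = diff.T rather than forming the inverse explicitly
        solved = np.linalg.solve(covariances[j], diff.T).T
        maha = np.sum(diff * solved, axis=1)
        # slogdet returns (sign, log|det|); the log-determinant avoids overflow
        _, log_det = np.linalg.slogdet(covariances[j])
        log_terms[:, j] = (np.log(mix_coefficients[j])
                           - 0.5 * (dim * np.log(2 * np.pi) + log_det + maha))
    # log-sum-exp over clusters, shifted by the row maximum for stability
    row_max = log_terms.max(axis=1, keepdims=True)
    log_mix = row_max[:, 0] + np.log(np.exp(log_terms - row_max).sum(axis=1))
    return log_mix.sum()
```

One possible use is to evaluate this after each M-step and stop early once the improvement falls below a small tolerance; the tolerance value is a tuning choice, not something the original code specifies.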