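Write a Gaussian mixture clustering algorithm in Python without using the sklearn library, run a clustering experiment on the Iris data, and visualize the result
Time: 2024-03-21 16:43:40 Views: 71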
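Sure. The code below implements Gaussian mixture clustering in Python and runs it on the Iris data. First, download the Iris data set (saved here as `iris.data`) and preprocess it: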
```python
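import numpy as np

# Read the data set (one comma-separated sample per line)
with open('iris.data', 'r') as f:
    data = f.readlines()

# Preprocessing: skip blank lines and drop the last column (the species label)
X = []
for line in data:
    if line.strip():
        x = line.strip().split(',')[:-1]
        X.append(list(map(float, x)))
X = np.array(X)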
```
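The parsing step above discards the species column. If the labels are wanted later, for example to check the clusters, they can be kept alongside the features. This is a minimal sketch, assuming each row has four comma-separated numbers followed by a species name; `parse_iris` is a hypothetical helper name, and the two sample rows are illustrative stand-ins for the real file:

```python
import io
import numpy as np

def parse_iris(lines):
    """Parse rows like '5.1,3.5,1.4,0.2,Iris-setosa' into features and species names."""
    features, species = [], []
    for line in lines:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        *values, name = line.split(',')
        features.append([float(v) for v in values])
        species.append(name)
    return np.array(features), species

# Tiny in-memory sample standing in for the real file (assumed row layout)
sample = io.StringIO("5.1,3.5,1.4,0.2,Iris-setosa\n7.0,3.2,4.7,1.4,Iris-versicolor\n\n")
X_demo, species_demo = parse_iris(sample)
print(X_demo.shape, species_demo)
```

With the real file, the same function could be called as `parse_iris(open('iris.data'))`.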
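Next, we implement the Gaussian mixture clustering algorithm with EM. It uses `scipy.stats.multivariate_normal` only to evaluate the Gaussian density: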
```python
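import numpy as np
from scipy.stats import multivariate_normal


def GMM(X, K, max_iter):
    n, m = X.shape
    # Initialize parameters: uniform weights, identity covariances, and
    # means drawn from K distinct samples (uniform values in [0, 1) would
    # lie far from the Iris data)
    pi = np.full(shape=K, fill_value=1 / K)
    means = X[np.random.choice(n, K, replace=False)].astype(float)
    covs = np.array([np.eye(m)] * K)
    for i in range(max_iter):
        # E-step: responsibility of each component for each sample
        gamma = np.zeros((n, K))
        for j in range(K):
            gamma[:, j] = pi[j] * multivariate_normal.pdf(X, means[j], covs[j])
        gamma = gamma / np.sum(gamma, axis=1, keepdims=True)
        # M-step: re-estimate weights, means, and covariances
        Nk = np.sum(gamma, axis=0)
        pi = Nk / n
        for j in range(K):
            means[j] = np.sum(X * gamma[:, j].reshape(-1, 1), axis=0) / Nk[j]
            x_mean = X - means[j]
            covs[j] = np.dot(gamma[:, j] * x_mean.T, x_mean) / Nk[j]
            # Small ridge keeps the covariance invertible
            covs[j] += 1e-6 * np.eye(m)
    # Assign each sample to its most responsible component
    return np.argmax(gamma, axis=1)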
```
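The loop above always runs for `max_iter` iterations. A common way to monitor EM is to track the data log-likelihood, which should not decrease from one iteration to the next. This is a minimal sketch, assuming parameters shaped like `pi`, `means`, and `covs` above; `gmm_log_likelihood` is a hypothetical helper, and it works in log space with `scipy.special.logsumexp` to avoid underflow. The demo data is synthetic, not Iris:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def gmm_log_likelihood(X, pi, means, covs):
    """Return sum_i log sum_k pi_k * N(x_i | mu_k, Sigma_k)."""
    K = len(pi)
    log_prob = np.column_stack([
        np.log(pi[k]) + multivariate_normal.logpdf(X, means[k], covs[k])
        for k in range(K)
    ])  # shape (n, K)
    return logsumexp(log_prob, axis=1).sum()

# Synthetic two-cluster data for illustration only
rng = np.random.default_rng(0)
X_demo = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(5, 1, (50, 2))])
pi_demo = np.array([0.5, 0.5])
covs_demo = np.array([np.eye(2), np.eye(2)])

# Means near the true centers should score higher than badly placed means
good = gmm_log_likelihood(X_demo, pi_demo, np.array([[0.0, 0.0], [5.0, 5.0]]), covs_demo)
bad = gmm_log_likelihood(X_demo, pi_demo, np.array([[10.0, 10.0], [-5.0, -5.0]]), covs_demo)
print(good > bad)
```

Inside the EM loop, one option is to stop early once the increase in this value falls below a small tolerance such as `1e-6` (an assumed value, not one taken from the answer).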
Finally, we visualize the clustering result:
```python
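import matplotlib.pyplot as plt

# Cluster the data and plot the first three features in 3D
labels = GMM(X, 3, 100)
fig = plt.figure()
# add_subplot creates the 3D axes; Axes3D(fig) no longer attaches
# itself to the figure in recent Matplotlib versions
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=labels)
ax.set_xlabel('sepal length (cm)')
ax.set_ylabel('sepal width (cm)')
ax.set_zlabel('petal length (cm)')
plt.show()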
```
Running the complete code produces a 3D scatter plot of the Iris samples, colored by their assigned cluster.
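The cluster indices returned by `GMM` are arbitrary, so comparing them with the species labels requires matching clusters to classes first. This is a minimal sketch, assuming the true labels are available as integers in `0..K-1`; `best_match_accuracy` is a hypothetical helper that tries every permutation, which is only practical for small K. The label lists are made-up toy values, not Iris results:

```python
from itertools import permutations
import numpy as np

def best_match_accuracy(true_labels, cluster_labels, K):
    """Highest agreement between clusters and classes over all relabelings."""
    true_labels = np.asarray(true_labels)
    cluster_labels = np.asarray(cluster_labels)
    best = 0.0
    for perm in permutations(range(K)):
        # perm[c] is the class assigned to cluster index c
        mapped = np.array(perm)[cluster_labels]
        best = max(best, float(np.mean(mapped == true_labels)))
    return best

# Toy labels for illustration only
true_demo = [0, 0, 1, 1, 2, 2]
pred_demo = [2, 2, 0, 0, 1, 0]
accuracy = best_match_accuracy(true_demo, pred_demo, 3)
print(accuracy)
```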