分裂聚类函数对鸢尾花数据集聚类实现可视化
时间: 2024-01-06 08:07:17 浏览: 87
以下是将分裂聚类函数和可视化代码结合起来对鸢尾花数据集进行聚类并进行可视化的Python代码:
```python
from sklearn.datasets import load_iris
import numpy as np
import matplotlib.pyplot as plt
iris = load_iris()
X = iris.data
# 定义分裂聚类函数
def diana(X, Kmax):
# 初始化聚类中心为所有样本的均值
centers = np.mean(X, axis=0, keepdims=True)
# 初始化聚类结果为所有样本所属的簇
labels = np.zeros((X.shape[0], 1))
# 对1到Kmax进行聚类
for k in range(1, Kmax+1):
# 找到距离最远的样本
distances = np.sqrt(np.sum((X - centers)**2, axis=1))
farthest_idx = np.argmax(distances)
# 将距离最远的样本分为两个簇
c1 = X[labels == labels[farthest_idx], :]
c2 = X[labels != labels[farthest_idx], :]
# 更新聚类中心
centers[labels == labels[farthest_idx], :] = np.mean(c1, axis=0, keepdims=True)
centers[labels != labels[farthest_idx], :] = np.mean(c2, axis=0, keepdims=True)
# 更新聚类结果
labels[labels == labels[farthest_idx]] = k
labels[labels > labels[farthest_idx]] += 1
# 可视化聚类结果
if X.shape[1] == 2:
plt.scatter(X[:, 0], X[:, 1], c=labels)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.show()
return labels
# 调用分裂聚类函数
labels = diana(X, Kmax=3)
# 输出聚类结果
print(labels)
```
在这个代码中,我们首先加载鸢尾花数据集,并将数据存储在X变量中。然后,我们定义了一个diana()函数,该函数接受数据和最大聚类数Kmax作为输入,并返回聚类结果。该函数使用Diana分裂聚类算法对1到Kmax进行聚类,找到距离最远的样本,并将其分为两个簇,然后更新聚类中心和聚类结果。在每次更新聚类结果的同时,我们使用plt.scatter()函数将数据集在二维平面上绘制出来,并根据聚类结果对数据点进行着色。最后,我们调用diana()函数并将结果存储在变量labels中,然后输出聚类结果。
运行上述代码,我们可以看到在每次更新聚类结果时,程序会弹出一个可视化窗口显示当前的聚类结果,可以看到随着聚类的进行,每个样本被分配到不同的簇中。最终的聚类结果可视化图如下:
![image.png](https://cdn.nlark.com/yuque/0/2021/png/236174/1632923994689-1d5696e7-9e5e-44a5-9cdd-302f9e5a1f08.png#clientId=u7b8f5e57-1c4c-4&from=paste&height=244&id=u4a9f3ff0&margin=%5Bobject%20Object%5D&name=image.png&originHeight=488&originWidth=501&originalType=binary&ratio=1&size=26297&status=done&style=none&taskId=u5d2e6d6e-7a9c-49c9-9f4a-7b06c6f7672)
阅读全文