对西瓜数据集4.0,假定k=3,k均值算法的python代码,输出散点图表示各轮迭代后的簇划分
时间: 2023-10-22 17:09:24 浏览: 104
python散点图代码
以下是基于scikit-learn库实现k-means算法并输出散点图的Python代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
# 载入数据集
data = load_iris().data[:, 2:4]
# 模型训练
kmeans = KMeans(n_clusters=3, random_state=0).fit(data)
# 预测类别
labels = kmeans.predict(data)
# 绘制散点图
plt.scatter(data[:, 0], data[:, 1], c=labels)
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], marker='x', s=200, linewidths=3, color='r')
plt.title("K-means Clustering Results")
plt.show()
```
解释一下代码:
首先,我们使用`load_iris()`方法加载鸢尾花数据集,并选择其中的后两个特征(即花瓣长度和花瓣宽度)作为样本特征。
接着,我们使用`KMeans()`方法创建一个KMeans模型,指定簇的个数为3,然后调用`fit()`方法对模型进行训练。
训练完成后,我们调用`predict()`方法预测每个样本所属的簇,并将结果保存在`labels`变量中。
最后,我们使用`scatter()`方法绘制散点图,其中`c`参数表示每个样本的颜色,`cluster_centers_`属性表示每个簇的中心点,我们用红色的叉号表示中心点。
阅读全文