对西瓜数据集4.0,假定k=3,k均值算法的python代码
时间: 2023-12-10 11:39:40 浏览: 65
以下是使用Python实现k均值算法的代码,假设数据集已经被加载到名为X的numpy数组中:
```python
import numpy as np
def dist(a, b):
# 计算两个向量之间的欧几里得距离
return np.linalg.norm(a - b)
def kmeans(X, k):
# 随机初始化k个聚类中心
centers = X[np.random.choice(np.arange(len(X)), k, replace=False)]
old_centers = np.zeros(centers.shape)
labels = np.zeros(len(X))
iterations = 0
# 当聚类中心不再改变时停止迭代
while not np.array_equal(centers, old_centers):
iterations += 1
old_centers = centers.copy()
# 分配每个数据点到最近的聚类中心
for i, x in enumerate(X):
distances = [dist(x, center) for center in centers]
labels[i] = np.argmin(distances)
# 更新聚类中心
for i in range(k):
cluster_points = [X[j] for j in range(len(X)) if labels[j] == i]
centers[i] = np.mean(cluster_points, axis=0)
return labels, centers, iterations
```
使用该函数可以得到每个数据点的聚类标签,以及聚类中心和迭代次数。例如,可以使用以下代码来对西瓜数据集进行聚类:
```python
import pandas as pd
# 加载数据集
data = pd.read_csv("watermelon_4.0.csv")
# 选择需要聚类的特征(密度和含糖率)
X = data[["density", "sugar_content"]].values
# 调用kmeans函数进行聚类
labels, centers, iterations = kmeans(X, k=3)
# 输出聚类结果
print("聚类标签:", labels)
print("聚类中心:", centers)
print("迭代次数:", iterations)
```
阅读全文