使用python实现k均值聚类
时间: 2023-08-31 11:12:40 浏览: 72
课程实验基于python实现K均值(k-means)聚类算法源码+项目说明.zip
5星 · 资源好评率100%
以下是使用Python实现k均值聚类的示例代码:
``` python
import numpy as np
class KMeans:
def __init__(self, k=3, max_iter=100):
self.k = k
self.max_iter = max_iter
def fit(self, X):
self.centroids = X[np.random.choice(range(len(X)), self.k, replace=False)]
for i in range(self.max_iter):
clusters = [[] for _ in range(self.k)]
for x in X:
distances = [np.linalg.norm(x - c) for c in self.centroids]
cluster_idx = np.argmin(distances)
clusters[cluster_idx].append(x)
prev_centroids = self.centroids
self.centroids = [np.mean(cluster, axis=0) for cluster in clusters]
if np.all(prev_centroids == self.centroids):
break
def predict(self, X):
distances = [np.linalg.norm(X - c, axis=1) for c in self.centroids]
cluster_idx = np.argmin(distances, axis=0)
return cluster_idx
```
在这个实现中,我们首先定义了一个KMeans类,它有两个参数:k是簇的数量,max_iter是最大迭代次数。在fit方法中,我们首先随机选择k个初始聚类中心,然后循环执行以下步骤:
1. 将每个样本分配到与其最近的聚类中心所在的簇中。
2. 计算每个簇的新聚类中心,并更新聚类中心。
3. 如果更新后的聚类中心与上一次迭代的聚类中心相同,则停止迭代。
在predict方法中,我们计算每个样本到每个聚类中心的距离,并返回它们所属的簇的索引。
下面是一个使用示例:
``` python
import matplotlib.pyplot as plt
X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]])
kmeans = KMeans(k=2)
kmeans.fit(X)
labels = kmeans.predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels)
plt.scatter(kmeans.centroids[:, 0], kmeans.centroids[:, 1], marker='x', s=200, linewidths=3, color='r')
plt.show()
```
这个例子中,我们生成了一个二维数据集,并使用KMeans类将其聚类为两个簇。最后,我们使用matplotlib库绘制了聚类结果。
阅读全文