k均值聚类代码
时间: 2023-07-10 08:10:30 浏览: 98
以下是Python实现的k均值聚类算法代码:
```Python
import numpy as np
class KMeans:
def __init__(self, n_clusters, max_iter=300):
self.n_clusters = n_clusters
self.max_iter = max_iter
def fit(self, X):
self.centroids = X[np.random.choice(len(X), self.n_clusters, replace=False)]
for i in range(self.max_iter):
clusters = [[] for _ in range(self.n_clusters)]
for x in X:
distances = [np.linalg.norm(x - c) for c in self.centroids]
cluster_index = np.argmin(distances)
clusters[cluster_index].append(x)
new_centroids = []
for cluster in clusters:
if len(cluster) == 0:
new_centroids.append(np.random.uniform(-1, 1, X.shape[1]))
else:
new_centroids.append(np.mean(cluster, axis=0))
if np.allclose(self.centroids, new_centroids):
break
self.centroids = new_centroids
def predict(self, X):
clusters = [[] for _ in range(self.n_clusters)]
for x in X:
distances = [np.linalg.norm(x - c) for c in self.centroids]
cluster_index = np.argmin(distances)
clusters[cluster_index].append(x)
return clusters
```
其中,fit方法用于训练模型,predict方法用于对新数据进行聚类。具体而言,fit方法首先随机选取k个数据点作为初始聚类中心,然后重复以下步骤:
1. 将所有数据点分配到距离最近的聚类中心所在的簇中。
2. 计算每个簇的中心(即新的聚类中心)。
3. 如果新的聚类中心与旧的聚类中心非常接近,则停止迭代。
最后,predict方法根据已经训练好的模型将新数据点分配到相应的簇中,返回每个簇中的数据点列表。
阅读全文