写一个python实现的聚类算法
时间: 2024-05-12 17:15:32 浏览: 6
这里给出一个基于KMeans算法的聚类实现:
```python
import numpy as np
import matplotlib.pyplot as plt
class KMeans:
def __init__(self, k=2, max_iter=100):
self.k = k
self.max_iter = max_iter
def fit(self, X):
self.centroids = X[np.random.choice(len(X), self.k, replace=False)]
for i in range(self.max_iter):
clusters = [[] for _ in range(self.k)]
for x in X:
distances = [np.linalg.norm(x - c) for c in self.centroids]
cluster_idx = np.argmin(distances)
clusters[cluster_idx].append(x)
new_centroids = []
for i, cluster in enumerate(clusters):
if cluster:
new_centroid = np.mean(cluster, axis=0)
new_centroids.append(new_centroid)
else:
new_centroids.append(self.centroids[i])
if np.allclose(self.centroids, new_centroids):
break
self.centroids = new_centroids
def predict(self, X):
distances = np.array([np.linalg.norm(X - c, axis=1) for c in self.centroids])
return np.argmin(distances, axis=0)
# 生成随机数据
np.random.seed(0)
X = np.random.randn(200, 2)
# 聚类
model = KMeans(k=3)
model.fit(X)
y = model.predict(X)
# 可视化
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.scatter(model.centroids[:, 0], model.centroids[:, 1], marker='*', s=200, c='r')
plt.show()
```
该实现中,KMeans类包含两个方法:fit和predict。fit方法用于训练模型,而predict方法用于对新数据进行预测。在fit方法中,首先随机选择k个中心点,然后迭代更新中心点和聚类结果,直到中心点不再变化或达到最大迭代次数为止。在predict方法中,通过计算新数据点与中心点的距离来确定其所属的聚类。最后,通过matplotlib库可视化聚类结果。