用Python写一个简易的K-means聚类算法
时间: 2024-04-12 17:08:59 浏览: 89
python实现k-means聚类算法
好的,下面是一个简易的 K-means 聚类算法的 Python 实现:
```python
import numpy as np
class KMeans:
def __init__(self, n_clusters=8, max_iter=300, tol=1e-4):
self.n_clusters = n_clusters # 聚类个数
self.max_iter = max_iter # 最大迭代次数
self.tol = tol # 迭代停止阈值
def fit(self, X):
self.centroids = X[np.random.choice(X.shape[0], self.n_clusters, replace=False)]
for i in range(self.max_iter):
clusters = [[] for _ in range(self.n_clusters)]
for x in X:
dists = np.linalg.norm(self.centroids - x, axis=1)
cluster_idx = np.argmin(dists)
clusters[cluster_idx].append(x)
new_centroids = np.array([np.mean(cluster, axis=0) for cluster in clusters])
if np.linalg.norm(new_centroids - self.centroids) < self.tol:
break
self.centroids = new_centroids
def predict(self, X):
dists = np.linalg.norm(self.centroids - X[:, np.newaxis], axis=2)
return np.argmin(dists, axis=1)
```
上述代码中,我们首先定义了一个 `KMeans` 类,它包含三个超参数:`n_clusters` 表示聚类个数,`max_iter` 表示最大迭代次数,`tol` 表示迭代停止阈值。接着我们定义了 `fit` 方法,它接受一个数据集 `X`,并在数据集上执行 K-means 聚类算法。具体来说,我们首先随机初始化聚类中心,然后执行若干次迭代,每次迭代分别计算每个样本点到每个聚类中心的距离,将每个样本点分配到距离最近的聚类中心所对应的簇中,并根据分配结果更新聚类中心。当聚类中心不再变化或者达到最大迭代次数时,算法停止。最后我们定义了 `predict` 方法,它接受一个数据集 `X`,并返回每个样本点所属的簇的编号。
阅读全文