手写KMeans如果要调参的话,调哪些参数,写出代码
时间: 2024-03-21 19:43:09 浏览: 51
手写KMeans需要调整的参数包括聚类的数量k,初始化质心的方法,以及停止迭代的条件等。下面是一个简单的KMeans实现,其中包含了一些可调参数的默认值:
```python
import numpy as np
class KMeans:
def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, init_method='random'):
self.n_clusters = n_clusters
self.max_iter = max_iter
self.tol = tol
self.init_method = init_method
def fit(self, X):
# 初始化质心
if self.init_method == 'random':
self.centroids = X[np.random.choice(len(X), self.n_clusters, replace=False)]
elif self.init_method == 'kmeans++':
self.centroids = [X[np.random.choice(len(X))]]
for i in range(self.n_clusters - 1):
distances = np.array([np.min([np.linalg.norm(x-c)**2 for c in self.centroids]) for x in X])
probs = distances / np.sum(distances)
self.centroids.append(X[np.random.choice(len(X), p=probs)])
else:
raise ValueError('Invalid init_method')
for i in range(self.max_iter):
# 计算每个样本属于哪个聚类
labels = np.argmin(np.linalg.norm(X[:, np.newaxis] - self.centroids, axis=2), axis=1)
# 更新质心
new_centroids = np.array([X[labels == j].mean(axis=0) for j in range(self.n_clusters)])
# 判断是否收敛
if np.linalg.norm(new_centroids - self.centroids) < self.tol:
break
self.centroids = new_centroids
```
其中,可以调整的参数包括:
- `n_clusters`:聚类的数量,默认为8。
- `max_iter`:最大迭代次数,默认为300。
- `tol`:停止迭代的阈值,当新旧质心之间的距离小于该值时停止迭代,默认为1e-4。
- `init_method`:初始化质心的方法,可以是'random'(随机选择数据点作为质心)或'kmeans++'(使用KMeans++算法选择质心),默认为'random'。
例如,如果要将聚类数量调整为10,可以这样使用:
```python
kmeans = KMeans(n_clusters=10)
kmeans.fit(X)
```
阅读全文