python实现不使用scikit-learn工具包,实现一个均值聚类算法 * 使用该算法将鸢尾花数据集聚成三类,并可视化聚类结果 * 改变聚类的簇数,并可视化聚类结果
时间: 2024-12-21 21:26:40 浏览: 4
在Python中,我们可以使用KMeans算法(基于距离的简单聚类方法)来实现一个基本的均值聚类。由于sklearn库已经包含了一个非常高效的KMeans实现,为了不使用它,我们将手动编写KMeans的基本步骤,包括初始化质心、迭代更新质心直到收敛等。这里我们先不使用鸢尾花数据集,因为它是sklearn自带的数据集,但我们先演示一个通用的例子。
首先,我们需要创建一个随机初始化中心点的函数:
```python
import numpy as np
def initialize_centroids(X, k):
n_samples, _ = X.shape
centroids = X[np.random.choice(range(n_samples), k, replace=False)]
return centroids
```
然后,我们定义一个函数来进行KMeans迭代:
```python
def kmeans(X, k, max_iters=100, tol=1e-4):
centroids = initialize_centroids(X, k)
prev_centroids = None
for iter in range(max_iters):
# Assign each point to the closest centroid
assignments = np.argmin(cdist(X, centroids), axis=1)
# Update centroids based on new assignment
new_centroids = np.array([X[assignments == i].mean(axis=0) for i in range(k)])
# Check convergence
if (prev_centroids is not None and np.allclose(new_centroids, prev_centroids, rtol=tol)):
break
prev_centroids = new_centroids
return assignments, new_centroids
```
`cdist`函数用于计算样本与质心之间的欧氏距离,这里可以自行实现或者使用其他库如`scipy.spatial.distance.cdist`。现在我们有了基本的KMeans算法,可以应用于任何数据集。
如果你有特定的数据集(不是鸢尾花),你可以将其转换为numpy数组并调用上面的函数。至于可视化,你可以使用matplotlib或者seaborn库来绘制散点图,每个颜色代表一类:
```python
import matplotlib.pyplot as plt
from matplotlib.collections import EllipseCollection
from scipy.spatial import ConvexHull
def plot_clusters(X, assignments, centroids, ax=None):
if ax is None:
fig, ax = plt.subplots()
colors = ['red', 'green', 'blue']
for i in range(len(colors)):
data_points = X[assignments == i]
ax.scatter(data_points[:, 0], data_points[:, 1], c=colors[i], label=f'Cluster {i+1}')
# Plot centroids as ellipses
hulls = [ConvexHull(point) for point in centroids]
ellipse_args = [{'facecolor': color, 'edgecolor': 'k', 'alpha': 0.5} for color in colors]
ax.add_collection(EllipseCollection(hulls=[hull.points for hull in hulls], **ellipse_args))
ax.legend()
ax.set_title('KMeans Clustering Result')
# 示例数据和可视化工厂
X = ... # 用户自定义的数据
assignments, centroids = kmeans(X, k=3)
plot_clusters(X, assignments, centroids)
plt.show()
```
改变聚类簇数只需将`k`参数调整为你想要的类别数即可。同样的,你需要替换`X`变量为实际的数据,比如鸢尾花数据集的特征矩阵。
阅读全文