Scikit-learn 中的 AgglomerativeClustering
时间: 2024-03-01 20:55:47 浏览: 18
AgglomerativeClustering 是 Scikit-learn 中的一个聚类算法,它属于层次聚类算法的一种,也叫做自底向上聚合算法。该算法的主要思想是将每个数据点看作一个单独的簇,然后将相似的簇合并成一个更大的簇,直到满足停止条件为止。在 AgglomerativeClustering 算法中,我们需要指定簇数或者合并的距离阈值来停止聚合过程。
AgglomerativeClustering 算法的主要参数包括:
- n_clusters:簇数,指定聚成几类,默认为 2。
- linkage:合并策略,指定合并簇的方式,包括 ward、complete、average 和 single 四种方式,默认为 ward 方式。
其中,ward 方式使用方差来度量簇的距离;complete 和 average 方式使用簇中最远点和平均点之间的距离来度量簇的距离;single 方式使用簇中最近点之间的距离来度量簇的距离。
使用 AgglomerativeClustering 算法进行聚类的主要步骤包括:
1. 加载数据集;
2. 数据预处理;
3. 构建聚类模型;
4. 训练聚类模型;
5. 预测聚类结果;
6. 可视化聚类结果。
下面是一个使用 AgglomerativeClustering 算法对鸢尾花数据集进行聚类的示例代码:
```python
# 加载数据集
from sklearn.datasets import load_iris
iris = load_iris()
# 数据预处理
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X = scaler.fit_transform(iris.data)
# 构建聚类模型
from sklearn.cluster import AgglomerativeClustering
model = AgglomerativeClustering(n_clusters=3, linkage='ward')
# 训练聚类模型
model.fit(X)
# 预测聚类结果
labels = model.labels_
# 可视化聚类结果
import matplotlib.pyplot as plt
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='rainbow')
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
plt.show()
```
在上面的代码中,我们首先使用 load_iris() 函数加载鸢尾花数据集,然后使用 StandardScaler() 函数进行数据标准化处理。接着,我们使用 AgglomerativeClustering() 函数构建聚类模型,其中 n_clusters=3 表示将数据聚成 3 类,linkage='ward' 表示使用 ward 方式进行合并簇的方式。然后,我们使用 fit() 函数训练聚类模型,使用 labels_ 属性获取聚类结果,最后使用 Matplotlib 库展示聚类结果。