pytorch实现k均值算法,并解释代码
时间: 2023-12-06 11:39:49 浏览: 74
基于GPU加速+Pytorch的K-Means聚类实现-附项目源码-优质项目实战.zip
5星 · 资源好评率100%
以下是基于 PyTorch 实现 k 均值算法的代码及解释:
```python
import torch
def kmeans(X, K, max_iters=100):
"""
X: 数据集,shape为 [n_samples, n_features]
K: 聚类数量
max_iters: 最大迭代次数
"""
n_samples = X.shape[0]
# 初始化聚类中心,随机选择K个样本为初始聚类中心
centroids_idx = torch.randint(0, n_samples, (K,))
centroids = X[centroids_idx]
# 初始化每个样本所属的聚类
cluster_assignments = torch.zeros(n_samples, dtype=torch.long)
for i in range(max_iters):
# 计算每个样本到每个聚类中心的距离
distances = torch.cdist(X, centroids)
# 分配样本到最近的聚类中心
new_assignments = torch.argmin(distances, dim=1)
# 如果新的聚类分配和旧的一样,则已经收敛
if torch.all(new_assignments == cluster_assignments):
break
cluster_assignments = new_assignments
# 更新聚类中心
for j in range(K):
centroids[j] = X[cluster_assignments == j].mean(dim=0)
return cluster_assignments, centroids
```
该函数接受如下参数:
- `X`:数据集,形状为 `[n_samples, n_features]`。
- `K`:聚类数量。
- `max_iters`:最大迭代次数,默认为 100。
函数首先初始化聚类中心,随机选择 K 个样本为初始聚类中心,并初始化每个样本所属的聚类。然后,它开始迭代,每次迭代包括以下步骤:
1. 计算每个样本到每个聚类中心的距离,使用 `torch.cdist` 函数实现。
2. 分配样本到最近的聚类中心,使用 `torch.argmin` 函数实现。
3. 如果新的聚类分配和旧的一样,则已经收敛,跳出迭代。
4. 更新聚类中心,对于每个聚类,计算该聚类中所有样本的平均值并将其作为新的聚类中心。
最后返回每个样本所属的聚类和最终的聚类中心。
这里使用 PyTorch 的张量操作实现了 k 均值算法,可以在 GPU 上加速执行。
阅读全文