# 计算二维平面上点的距离 def dis(a,b): return math.sqrt((a[0]-b[0])*(a[0]-b[0])+(a[1]-b[1])*(a[1]-b[1])) # K-Means # 一般执行10次以内即可完成分类 for p in range(10): # NewPoint初始化为0 NewPoint = [[0, 0] for i in range(K)] for i in range(len(X)): mDis=1e9 mC=0 for j in range(len(CentPoint)): cp=CentPoint[j] D = dis([X[i].item(), y[i].item()], cp) # print("distance:", D) if mDis>D: mDis=D mC=j C[i]=mC # print("mC",mC,C[i].item()) NewPoint[mC][0]+=X[i].item() NewPoint[mC][1]+=y[i].item() # 更新中心点 for i in range(K): CentPoint[i][0]=NewPoint[i][0]/2000 CentPoint[i][1]=NewPoint[i][1]/2000 # 输出中心点,观察变化过程 print(CentPoint)
时间: 2024-03-31 16:38:29 浏览: 32
点到平面的距离1
这段代码是 K-Means 算法的核心实现部分。首先定义了一个计算两个点之间距离的函数 dis,接着进入 K-Means 算法的迭代过程,循环执行 10 次以内即可完成分类。对于每个样本点,算法会计算它到每个聚类中心的距离,然后将该点划分到距离最近的聚类中心所在的类别中。在这个过程中,C 是一个大小为 2000 的张量,记录了每个点的类别归属。然后,将所有属于同一个类别的点的坐标相加,求出新的聚类中心并更新 CentPoint。最后输出更新后的聚类中心。在这个过程中,K-Means 算法通过不断迭代来调整聚类中心,直到聚类中心不再发生变化或者达到最大迭代次数为止。
阅读全文