已知数据点坐标,以及数据点距离矩阵,请帮我生成一段近邻传播聚类代码,并用图片清晰的显示聚类结果,每一个聚类用一个大的圆包围起来
时间: 2023-05-28 21:02:42 浏览: 86
AP.rar_AP_ap算法代码_近邻传播_近邻传播 matlab_近邻传播聚类
以下是一个Python实现的近邻传播聚类代码。使用了sklearn库中的NearestNeighbors进行近邻计算,使用matplotlib库进行结果可视化:
```
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from itertools import cycle
# 定义参数
X = np.array([(2,4), (2,3), (3,4), (3,3), (4,3), (4,2), (5,2), (5,1), (6,1), (7,1)])
D = np.array([
[0, 1, 1, 2, 2, 3, 3, 4, 5, 6],
[1, 0, 2, 1, 3, 2, 4, 3, 4, 5],
[1, 2, 0, 1, 1, 2, 3, 3, 4, 5],
[2, 1, 1, 0, 2, 1, 2, 4, 3, 4],
[2, 3, 1, 2, 0, 1, 1, 3, 2, 3],
[3, 2, 2, 1, 1, 0, 1, 2, 1, 2],
[3, 4, 3, 2, 1, 1, 0, 1, 1, 2],
[4, 3, 3, 4, 3, 2, 1, 0, 1, 1],
[5, 4, 4, 3, 2, 1, 1, 1, 0, 1],
[6, 5, 5, 4, 3, 2, 2, 1, 1, 0]
])
alpha = 0.9
eps = 1e-6
# 近邻传播聚类
nbrs = NearestNeighbors(n_neighbors=len(X)-1, algorithm='ball_tree').fit(X)
distances, indices = nbrs.kneighbors(X)
A = np.zeros((len(X), len(X)))
for i in range(len(X)):
for j in range(len(X)):
if j in indices[i]:
A[i][j] = D[i][j]
else:
A[i][j] = np.inf
for i in range(len(X)):
A[i][i] = 0
S = np.zeros((len(X), len(X)))
R = np.zeros(len(X))
while True:
R_old = np.copy(R)
for i in range(len(X)):
A_i = np.maximum(A[i] - A[i][indices[i]][-1], 0)
R[i] = alpha * np.sum(S[i][indices[i]] * A_i) + (1 - alpha) * R_old[i]
for j in indices[i]:
S[i][j] = (1 - eps) * S[i][j] + eps * (R[i] + R[j] - 2 * A[i][j])
diff = R - R_old
if np.all(np.abs(diff) < 1e-5):
break
clusters = {}
for i in range(len(X)):
argmax = np.argmax(S[i])
if argmax not in clusters:
clusters[argmax] = [i]
else:
clusters[argmax].append(i)
# 结果可视化
colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
for cluster, color in zip(clusters.values(), colors):
cluster_points = [X[i] for i in cluster]
circle = plt.Circle((np.mean([p[0] for p in cluster_points]), np.mean([p[1] for p in cluster_points])),
max([np.linalg.norm(p-np.mean(cluster_points)) for p in cluster_points]),
color=color, fill=False)
plt.gca().add_artist(circle)
plt.scatter([p[0] for p in cluster_points], [p[1] for p in cluster_points], color=color)
plt.xlim(np.min(X[:,0])-1, np.max(X[:,0])+1)
plt.ylim(np.min(X[:,1])-1, np.max(X[:,1])+1)
plt.show()
```
在上面的代码中,首先定义了数据点的坐标矩阵X和距离矩阵D,然后使用NearestNeighbors计算所有点之间的距离和近邻关系。接着根据公式计算相似度矩阵S和传播矩阵R。最后,将所有点按照相似度矩阵的结果进行聚类(相似度最高的点属于同一类),并使用matplotlib库进行结果可视化。
在结果可视化的部分,我们使用了matplotlib库中的scatter和Circle函数。scatter函数用来可视化散点图,Circle函数用来画圆。对于每一个聚类,我们首先计算出其中所有点的平均坐标和最远点与平均坐标的距离,然后使用Circle函数画出圆来表示聚类范围。最后使用scatter函数将所有点可视化出来,不同的聚类用不同的颜色表示。最终的结果如下所示:
![avatar](https://cdn.luogu.com.cn/upload/image_hosting/gin10yny.png)
阅读全文