def cal_frechet_distance(curve_a: np.ndarray, curve_b: np.ndarray): def euc_dist(pt1, pt2): return np.sqrt(np.square(pt2[0] - pt1[0]) + np.square(pt2[1] - pt1[1])) # 用递归方式计算,遍历整个ca矩阵 def _c(ca, i, j, P, Q): if ca[i, j] > -1: return ca[i, j] elif i == 0 and j == 0: ca[i, j] = euc_dist(P[0], Q[0]) elif i > 0 and j == 0: ca[i, j] = max(_c(ca, i - 1, 0, P, Q), euc_dist(P[i], Q[0])) elif i == 0 and j > 0: ca[i, j] = max(_c(ca, 0, j - 1, P, Q), euc_dist(P[0], Q[j])) elif i > 0 and j > 0: ca[i, j] = max(min(_c(ca, i - 1, j, P, Q), _c(ca, i - 1, j - 1, P, Q), _c(ca, i, j - 1, P, Q)), euc_dist(P[i], Q[j])) else: ca[i, j] = float("inf") return ca[i, j] def frechet_distance(P, Q): ca = np.ones((len(P), len(Q))) ca = np.multiply(ca, -1) dis = _c(ca, len(P) - 1, len(Q) - 1, P, Q) # ca为全-1的矩阵,shape = ( len(a), len(b) ) return dis # 构造计算序列 curve_line_a = list(zip(range(len(curve_a)), curve_a)) curve_line_b = list(zip(range(len(curve_b)), curve_b)) return frechet_distance(curve_line_a, curve_line_b) 代码 优化
时间: 2023-11-29 12:06:49 浏览: 88
这段代码可以进行一些优化,如下所示:
1. 使用 scipy.spatial.distance.cdist 函数来计算欧几里得距离,可以代替 euc_dist 函数,代码如下:
```
from scipy.spatial.distance import cdist
def frechet_distance(P, Q):
ca = cdist(P, Q)
ca = np.pad(ca, ((1, 0), (1, 0)), 'constant', constant_values=np.inf)
ca[0, 0] = 0
for i in range(1, len(P) + 1):
for j in range(1, len(Q) + 1):
ca[i, j] = max(min(ca[i-1, j], ca[i-1, j-1], ca[i, j-1]), ca[i, j])
return ca[-1, -1]
```
2. 可以使用 numpy 的广播机制来代替循环,代码如下:
```
def frechet_distance(P, Q):
ca = cdist(P, Q)
ca = np.pad(ca, ((1, 0), (1, 0)), 'constant', constant_values=np.inf)
ca[0, 0] = 0
ca[1:, 1:] = np.maximum(np.minimum(ca[:-1, 1:], ca[:-1, :-1], ca[1:, :-1]), ca[1:, 1:])
return ca[-1, -1]
```
阅读全文