def cal_frechet_distance(curve_a: np.ndarray, curve_b: np.ndarray): def euc_dist(pt1, pt2): return np.sqrt(np.square(pt2[0] - pt1[0]) + np.square(pt2[1] - pt1[1])) # 用递归方式计算,遍历整个ca矩阵 def _c(ca, i, j, P, Q): if ca[i, j] > -1: return ca[i, j] elif i == 0 and j == 0: ca[i, j] = euc_dist(P[0], Q[0]) elif i > 0 and j == 0: ca[i, j] = max(_c(ca, i - 1, 0, P, Q), euc_dist(P[i], Q[0])) elif i == 0 and j > 0: ca[i, j] = max(_c(ca, 0, j - 1, P, Q), euc_dist(P[0], Q[j])) elif i > 0 and j > 0: ca[i, j] = max(min(_c(ca, i - 1, j, P, Q), _c(ca, i - 1, j - 1, P, Q), _c(ca, i, j - 1, P, Q)), euc_dist(P[i], Q[j])) else: ca[i, j] = float("inf") return ca[i, j] def frechet_distance(P, Q): ca = np.ones((len(P), len(Q))) ca = np.multiply(ca, -1) dis = _c(ca, len(P) - 1, len(Q) - 1, P, Q) # ca为全-1的矩阵,shape = ( len(a), len(b) ) return dis # 构造计算序列 curve_line_a = list(zip(range(len(curve_a)), curve_a)) curve_line_b = list(zip(range(len(curve_b)), curve_b)) return frechet_distance(curve_line_a, curve_line_b) 代码 优化
时间: 2023-11-29 21:06:49 浏览: 51
以下是代码的优化建议:
1. 将_euc_dist函数内联到_c函数中,以避免函数调用开销。
2. 将_ca矩阵的初始化移到_c函数中,以避免每次调用frechet_distance函数都要重新初始化。
3. 将_ca矩阵的计算改为使用for循环,以避免使用递归时的函数调用开销。
4. 将_ca矩阵的索引从[i, j]改为[j, i],以便在访问矩阵元素时更容易理解。
下面是优化后的代码:
```
def cal_frechet_distance(curve_a: np.ndarray, curve_b: np.ndarray):
def _c(ca, i, j, P, Q):
if ca[i, j] > -1:
return ca[i, j]
elif i == 0 and j == 0:
ca[i, j] = np.sqrt(np.square(Q[0][0] - P[0][0]) + np.square(Q[0][1] - P[0][1]))
else:
if i > 0 and j == 0:
ca[i, j] = max(_c(ca, i - 1, 0, P, Q), np.sqrt(np.square(Q[0][0] - P[i][0]) + np.square(Q[0][1] - P[i][1])))
elif i == 0 and j > 0:
ca[i, j] = max(_c(ca, 0, j - 1, P, Q), np.sqrt(np.square(Q[j][0] - P[0][0]) + np.square(Q[j][1] - P[0][1])))
elif i > 0 and j > 0:
ca[i, j] = max(min(_c(ca, i - 1, j, P, Q), _c(ca, i - 1, j - 1, P, Q), _c(ca, i, j - 1, P, Q)),
np.sqrt(np.square(Q[j][0] - P[i][0]) + np.square(Q[j][1] - P[i][1])))
else:
ca[i, j] = float("inf")
return ca[i, j]
def frechet_distance(P, Q):
ca = np.ones((len(P), len(Q))) * -1
_c(ca, len(P) - 1, len(Q) - 1, P, Q)
return ca[-1, -1]
curve_line_a = list(zip(range(len(curve_a)), curve_a))
curve_line_b = list(zip(range(len(curve_b)), curve_b))
return frechet_distance(curve_line_a, curve_line_b)
```
阅读全文