import numpy as np from numpy import int64 def bilinear_interp_vectorized(a: np.ndarray, b: np.ndarray) -> np.ndarray: """ This is the vectorized implementation of bilinear interpolation. - a is a ND array with shape [N, H1, W1, C], dtype = int64 - b is a ND array with shape [N, H2, W2, 2], dtype = float64 - return a ND array with shape [N, H2, W2, C], dtype = int64 """ # Get axis size from ndarray shape N, H1, W1, C = a.shape N1, H2, W2, _ = b.shape assert N == N1 # Calculate the indices for interpolation x_idx = np.floor(b[:, :, :, 0]).astype(int) y_idx = np.floor(b[:, :, :, 1]).astype(int) x = b[:, :, :, 0] - x_idx y = b[:, :, :, 1] - y_idx # Calculate the interpolation values res = a[:, x_idx, y_idx] * (1 - x) * (1 - y) + a[:, x_idx + 1, y_idx] * x * (1 - y) + a[:, x_idx, y_idx + 1] * (1 - x) * y + a[:, x_idx + 1, y_idx + 1] * x * y return res.astype(int64) 在res = a[:, x_idx, y_idx] * (1 - x) * (1 - y) + a[:, x_idx + 1, y_idx] * x * (1 - y) + a[:, x_idx, y_idx + 1] * (1 - x) * y + a[:, x_idx + 1, y_idx + 1] * x * y发生错误operands could not be broadcast together with shapes (8,8,720,1280,4) (8,720,1280)
时间: 2024-04-08 07:35:22 浏览: 48
numpy-1.22.4+vanilla-cp311-cp311-win-amd64-whl
这个错误是由于形状不匹配导致的。你可以在计算 `res` 之前,先调整 `x_idx` 和 `y_idx` 的形状,使其与 `a` 的形状相匹配。可以使用 `np.expand_dims` 函数来添加额外的维度。修改代码如下:
```python
x_idx = np.expand_dims(x_idx, axis=-1)
y_idx = np.expand_dims(y_idx, axis=-1)
```
然后再重新计算 `res`。
希望这可以解决你的问题!如果还有其他问题,请随时提问。
阅读全文