import numpy as np
from numpy import int64


def bilinear_interp_baseline(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Bilinearly interpolate ``a`` at the coordinates in ``b`` (loop version).

    This is the non-vectorized reference implementation: it visits every
    output position with explicit Python loops.

    Args:
        a: Array with shape [N, H1, W1, C], dtype int64.
        b: Array with shape [N, H2, W2, 2], dtype float64. ``b[n, i, j]``
            holds the (x, y) sampling coordinate, where x runs along the
            H1 axis of ``a`` and y along the W1 axis.

    Returns:
        Array with shape [N, H2, W2, C], dtype int64.

    Raises:
        AssertionError: If ``a`` and ``b`` disagree on the batch size N.

    Note:
        For simplicity, all x are assumed to lie in [0, H1 - 1) and all y in
        [0, W1 - 1), so the ``+ 1`` neighbours are always in range.
    """
    # Get axis sizes from the array shapes.
    N, _, _, C = a.shape
    N1, H2, W2, _ = b.shape
    assert N == N1

    res = np.empty((N, H2, W2, C), dtype=int64)
    for n in range(N):
        for i in range(H2):
            for j in range(W2):
                x, y = b[n, i, j]
                # Top-left corner of the cell that contains (x, y).
                x_idx, y_idx = int(np.floor(x)), int(np.floor(y))
                # Fractional offsets inside that cell, each in [0, 1).
                _x, _y = x - x_idx, y - y_idx
                # The weighted sum is float64; storing it into the int64
                # output truncates the fractional part.
                res[n, i, j] = (
                    a[n, x_idx, y_idx] * (1 - _x) * (1 - _y)
                    + a[n, x_idx + 1, y_idx] * _x * (1 - _y)
                    + a[n, x_idx, y_idx + 1] * (1 - _x) * _y
                    + a[n, x_idx + 1, y_idx + 1] * _x * _y
                )
    return res


# Task: rewrite the code above in a vectorized form.
时间: 2024-04-08 09:35:19 浏览: 155
下面是将代码改写为向量化方式的版本:
```python
import numpy as np
from numpy import int64
def bilinear_interp_vectorized(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Bilinearly interpolate ``a`` at the coordinates in ``b`` (vectorized).

    Args:
        a: Array with shape [N, H1, W1, C], dtype int64.
        b: Array with shape [N, H2, W2, 2], dtype float64. ``b[n, i, j]``
            holds the (x, y) sampling coordinate, where x runs along the
            H1 axis of ``a`` and y along the W1 axis.

    Returns:
        Array with shape [N, H2, W2, C], dtype int64.

    Raises:
        AssertionError: If ``a`` and ``b`` disagree on the batch size N.

    Note:
        All x are assumed to lie in [0, H1 - 1) and all y in [0, W1 - 1),
        so the ``+ 1`` neighbours are always in range.
    """
    # Unpacking also enforces that both inputs are 4-D.
    N, _, _, _ = a.shape
    N1, _, _, _ = b.shape
    assert N == N1

    x = b[..., 0]
    y = b[..., 1]

    # Top-left corner of the cell containing each sample; shape [N, H2, W2].
    x_idx = np.floor(x).astype(np.int64)
    y_idx = np.floor(y).astype(np.int64)

    # Fractional offsets inside the cell. The trailing axis lets the weights
    # broadcast over the channel axis C of the gathered values.
    fx = (x - x_idx)[..., np.newaxis]
    fy = (y - y_idx)[..., np.newaxis]

    # Batch index with shape [N, 1, 1]. It broadcasts against x_idx / y_idx so
    # that output element n only reads from a[n]. Using a plain ``:`` here
    # would instead pair every batch with every coordinate set and yield
    # shape [N, N, H2, W2, C].
    n_idx = np.arange(N)[:, np.newaxis, np.newaxis]

    # Each gather has shape [N, H2, W2, C].
    res = (
        a[n_idx, x_idx, y_idx] * (1 - fx) * (1 - fy)
        + a[n_idx, x_idx + 1, y_idx] * fx * (1 - fy)
        + a[n_idx, x_idx, y_idx + 1] * (1 - fx) * fy
        + a[n_idx, x_idx + 1, y_idx + 1] * fx * fy
    )
    # Truncate the float64 result to int64, as the loop baseline does when it
    # stores into its int64 output buffer.
    return res.astype(np.int64)
```
这个向量化的版本通过使用整数数组索引(高级索引)和广播操作,避免了显式的循环,从而提高了代码的执行效率。
阅读全文