PyTorch 中的 where 函数的代码
时间: 2024-05-13 19:14:04 浏览: 56
pytorch学习代码
PyTorch 中的 where 函数可以根据条件返回两个张量中对应位置的元素。其代码如下所示:
```python
torch.where(condition, x, y)
```
其中,`condition` 为一个布尔型张量,表示条件;`x` 和 `y` 都是张量,表示两个待选择的张量,返回的张量与 `x` 和 `y` 的形状相同。
具体使用示例如下:
```python
import torch
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([5, 6, 7, 8])
mask = torch.tensor([True, False, True, False])
c = torch.where(mask, a, b)
print(c) # tensor([1, 6, 3, 8])
```
在上面的示例中,`mask` 表示一个布尔型张量,`True` 表示选择 `a` 中对应位置的元素,`False` 表示选择 `b` 中对应位置的元素。因此,`c` 的结果为 `[1, 6, 3, 8]`。
阅读全文