masked_fill
时间: 2023-09-03 10:07:04 浏览: 98
pytorch masked_fill报错的解决
`masked_fill` 是 PyTorch 中的一个操作,它可以对一个张量进行操作,并根据指定的掩码(mask)在特定位置填充指定的值。掩码是一个跟原始张量形状相同的张量,其中的元素是 0 或 1,表示哪些位置需要被填充,哪些位置不需要被填充。通常情况下,掩码中的 0 表示不需要填充,1 表示需要填充。
例如,假设我们有如下张量和掩码:
```python
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
```
我们想要将 `x` 中的所有奇数位置填充为 -1,可以使用 `masked_fill` 操作:
```python
x.masked_fill(mask == 1, -1)
```
操作的结果是:
```
tensor([[ 1, -1, 3],
[-1, 5, -1],
[ 7, -1, 9]])
```
可以看到,`x` 中第 1、3、5、7 个位置是奇数,对应的掩码中的值为 1,因此在这些位置上填充了 -1。
阅读全文