TORCH.MASKED介绍
时间: 2023-08-27 19:15:56 浏览: 42
`torch.masked_select(input, mask, out=None)` 是 PyTorch 中的一个函数,用于从输入张量 `input` 中提取满足条件的元素。其中,条件由布尔张量 `mask` 定义,`mask` 中为 `True` 的位置对应的 `input` 中的元素会被提取。该函数返回一个一维张量,包含所有满足条件的元素,元素顺序与在 `input` 张量中的顺序相同。
`torch.masked_select()` 的使用示例:
``` python
import torch
# 创建一个张量
x = torch.randn(3, 4)
print(x)
# tensor([[-0.1652, 0.1042, -0.1578, -0.6508],
# [-1.1958, 0.9910, 0.2753, -0.2862],
# [-0.3186, 0.1484, -1.7247, -1.1172]])
# 创建一个布尔张量
mask = x > 0
print(mask)
# tensor([[False, True, False, False],
# [False, True, True, False],
# [False, True, False, False]])
# 提取满足条件的元素
result = torch.masked_select(x, mask)
print(result)
# tensor([0.1042, 0.9910, 0.2753, 0.1484])
```
上述代码中,首先创建了一个大小为 $3\times 4$ 的张量 `x` 和一个与 `x` 大小相同的布尔张量 `mask`,其中 `mask` 中为 `True` 的位置对应的 `x` 中的元素会被提取。然后,使用 `torch.masked_select()` 函数从 `x` 中提取满足条件的元素,结果保存在 `result` 中。最终输出结果为一个一维张量,包含所有满足条件的元素。