pytorch 取出tensor中大于某个值的索引
时间: 2023-12-30 22:16:32 浏览: 337
使用pytorch 筛选出一定范围的值
可以使用 pytorch 中的函数 `torch.where()` 来取出 tensor 中大于某个值的索引。具体步骤如下:
1. 首先使用 `torch.gt()` 函数来比较 tensor 中的每个元素是否大于目标值,返回一个 bool 类型的 tensor。
2. 然后使用 `torch.nonzero()` 函数来取出 bool tensor 中非零元素的索引,即大于目标值的元素索引。
以下是示例代码:
```python
import torch
# 创建一个大小为 3x3 的 tensor
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 取出大于 5 的元素索引
indices = torch.nonzero(torch.gt(a, 5))
print(indices)
```
输出结果为:
```
tensor([[1, 2],
[2, 0],
[2, 1],
[2, 2]])
```
这表示 tensor 中第 2 行第 3 列、第 3 行第 1 列、第 3 行第 2 列、第 3 行第 3 列的元素大于 5。
阅读全文