pytorch 取出tensor中大于某个值的索引
时间: 2023-12-30 10:16:33 浏览: 314
详解PyTorch中Tensor的高阶操作
可以使用 PyTorch 中的 `torch.nonzero()` 函数来获取一个 tensor 中大于某个值的索引。示例如下:
```python
import torch
# 创建一个 3x3 的 tensor
x = torch.tensor([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]])
# 获取大于 0.5 的元素的索引
indices = torch.nonzero(x > 0.5)
print(indices)
```
输出结果为:
```
tensor([[1, 2],
[2, 0],
[2, 1],
[2, 2]])
```
其中,输出结果的每一行表示一个符合条件的元素的索引,第一列表示行索引,第二列表示列索引。
阅读全文