如何返回x = torch.tensor([[1, 2, 3][1,4,5]]) 中4的索引(1,1)
时间: 2024-01-10 14:02:03 浏览: 48
one hot编码:`torch.Tensor.scatter_()`函数用法详解
对于给定的张量 x,可以使用 PyTorch 的 where 函数来获取所有满足给定条件的元素的索引。下面是实现:
```python
import torch
x = torch.tensor([[1, 2, 3], [1, 4, 5]])
# 使用 where 函数获取所有满足 x == 4 的元素的索引
indices = torch.where(x == 4)
# 输出元素的索引
print(indices)
```
输出结果为:
```
(tensor([1]), tensor([1]))
```
可以看到,元素 4 的索引为 (1, 1)。注意,这里返回的索引是一个元组,第一个元素是行索引,第二个元素是列索引。
阅读全文