torch.nonzero
时间: 2023-07-03 14:15:47 浏览: 159
`torch.nonzero()` 是一个 PyTorch 中的函数,它用于返回一个张量中非零元素的索引。这个函数的使用方法如下:
```
torch.nonzero(input, *, out=None) -> LongTensor
```
其中,`input` 表示输入的张量,`out` 表示输出的张量。函数返回一个二维的张量,其中每行表示一个非零元素的索引,每列表示一个维度的索引。例如,对于一个形状为 `(3, 4)` 的张量,`torch.nonzero(tensor)` 将返回一个形状为 `(n, 2)` 的张量,其中 `n` 表示非零元素的个数,每行表示一个非零元素的索引,第一列为行索引,第二列为列索引。
下面是一个例子:
```
import torch
x = torch.tensor([[0, 1, 0, 0],
[2, 0, 0, 3],
[0, 0, 4, 0]])
print(torch.nonzero(x)) # 输出 tensor([[0, 1], [1, 0], [1, 3], [2, 2]])
```
这个例子中,输入的张量 `x` 包含了一些非零元素。`torch.nonzero(x)` 返回一个二维张量,其中每行表示一个非零元素的索引。具体来说,第一个非零元素为 `1`,位于第 `0` 行第 `1` 列,第二个非零元素为 `2`,位于第 `1` 行第 `0` 列,以此类推。因此,`torch.nonzero(x)` 返回了一个形状为 `(4, 2)` 的张量。
阅读全文