[[3.1641e-01, 3.6478e-04, 3.6478e-04, 3.1641e-01, 1.1563e-05, 3.4451e-05, 3.1641e-01, 1.5557e-05, 4.5593e-02, 2.2829e-04, 8.5831e-06, 1.2803e-04, 3.7422e-03, 1.6701e-04, 2.0564e-05, 1.3423e-04], [4.0703e-03, 8.7036e-02, 8.7036e-02, 4.0703e-03, 1.2517e-06, 1.7405e-05, 4.0703e-03, 1.2338e-05, 8.1299e-01, 1.3089e-04, 9.8765e-05, 3.8087e-05, 5.2869e-05, 1.3924e-04, 3.8743e-06, 2.7275e-04], [4.6272e-03, 1.4844e-01, 1.4844e-01, 4.6272e-03, 9.5367e-06, 1.7464e-05, 4.6272e-03, 8.8811e-05, 6.8652e-01, 2.8658e-04, 3.5119e-04, 9.4533e-05, 1.2326e-04, 7.3195e-04, 5.7340e-05, 7.6723e-04]]这样一个张量,如何对其中的每一个维度上的列表选取最大值的索引
时间: 2023-07-22 15:03:55 浏览: 165
可以使用 PyTorch 中的 `argmax` 函数来对张量的每一个维度上的列表选取最大值的索引,具体代码如下:
```python
import torch
# 定义一个张量
tensor = torch.tensor([[3.1641e-01, 3.6478e-04, 3.6478e-04, 3.1641e-01, 1.1563e-05, 3.4451e-05, 3.1641e-01, 1.5557e-05, 4.5593e-02, 2.2829e-04, 8.5831e-06, 1.2803e-04, 3.7422e-03, 1.6701e-04, 2.0564e-05, 1.3423e-04],
[4.0703e-03, 8.7036e-02, 8.7036e-02, 4.0703e-03, 1.2517e-06, 1.7405e-05, 4.0703e-03, 1.2338e-05, 8.1299e-01, 1.3089e-04, 9.8765e-05, 3.8087e-05, 5.2869e-05, 1.3924e-04, 3.8743e-06, 2.7275e-04],
[4.6272e-03, 1.4844e-01, 1.4844e-01, 4.6272e-03, 9.5367e-06, 1.7464e-05, 4.6272e-03, 8.8811e-05, 6.8652e-01, 2.8658e-04, 3.5119e-04, 9.4533e-05, 1.2326e-04, 7.3195e-04, 5.7340e-05, 7.6723e-04]])
# 对每一个维度上的列表选取最大值的索引
max_indices = torch.argmax(tensor, dim=1) # 对第二个维度进行操作,即选取每个列表的最大值索引
print(max_indices) # 输出每个列表最大值的索引
```
运行结果:
```
tensor([0, 8, 8])
```
上述代码中,`argmax` 函数的参数 `dim` 指定了在哪一个维度上进行操作。对于这个张量,我们要对每个列表选取最大值的索引,因此我们要对第二个维度进行操作,即 `dim=1`。最终得到的 `max_indices` 张量的每一个元素表示对应列表最大值的索引。
阅读全文