argmax(dim=1)
Date: 2023-11-25 17:06:18
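`argmax(dim=1)` is a PyTorch tensor method (`torch.Tensor.argmax`, equivalent to `torch.argmax(x, dim=1)`) that returns the indices of the maximum values along dimension 1 of a tensor. The reduced dimension is dropped from the result unless `keepdim=True` is passed.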
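To make the "along a dimension" idea concrete, here is a minimal pure-Python sketch of what the reduction computes on a 2-D nested list. The helpers `argmax_rows` and `argmax_cols` are hypothetical illustrations, not part of PyTorch: the first mimics `dim=1` (one index per row), the second mimics `dim=0` (one index per column). The strict `>` comparison keeps the first index when values tie, which, as far as we know, matches the tie rule described in recent PyTorch documentation.

```python
# Hypothetical helpers (not PyTorch APIs) that mimic argmax on a 2-D list.

def argmax_rows(matrix):
    """Like argmax(dim=1): for each row, the column index of its maximum."""
    result = []
    for row in matrix:
        best = 0
        for j, value in enumerate(row):
            # Strict ">" keeps the first index when values tie
            if value > row[best]:
                best = j
        result.append(best)
    return result


def argmax_cols(matrix):
    """Like argmax(dim=0): for each column, the row index of its maximum."""
    result = []
    for j in range(len(matrix[0])):
        best = 0
        for i in range(len(matrix)):
            if matrix[i][j] > matrix[best][j]:
                best = i
        result.append(best)
    return result


m = [
    [1, 5, 3],
    [7, 2, 7],  # tie between columns 0 and 2: the first (0) wins
]
print(argmax_rows(m))  # [1, 0]
print(argmax_cols(m))  # [1, 0, 1]
```

A 2×3 input yields 2 indices for `dim=1` and 3 for `dim=0`: the reduced dimension disappears from the result.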
For example, consider a tensor `x` of shape (3, 4) containing random values:
```
import torch
x = torch.rand(3, 4)
print(x)
```
Output (one possible result; the values change on every run because `torch.rand` is not seeded):
```
tensor([[0.3523, 0.9463, 0.9542, 0.7699],
        [0.4076, 0.3148, 0.4837, 0.4491],
        [0.4619, 0.1391, 0.6055, 0.1525]])
```
If we apply `argmax(dim=1)` to this tensor, it reduces over dimension 1 (the columns) and returns, for each row, the column index of that row's maximum value, giving a result of shape (3,):
```
idx = x.argmax(dim=1)
print(idx)
```
Output:
```
tensor([2, 2, 2])
```
This means that in every row of this particular `x`, the maximum value lies in the third column (index 2). With other random values the index will generally differ from row to row.
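Because `torch.rand` produces new values on each run, the sketch below fixes the sample values printed above with `torch.tensor` so the results are reproducible. It contrasts `dim=1` with `dim=0`, shows the effect of `keepdim=True`, and uses `max(dim=1)`, which returns the maximum values and their indices together.

```python
import torch

# Fix the sample values printed above so the results are reproducible
x = torch.tensor([
    [0.3523, 0.9463, 0.9542, 0.7699],
    [0.4076, 0.3148, 0.4837, 0.4491],
    [0.4619, 0.1391, 0.6055, 0.1525],
])

row_idx = x.argmax(dim=1)             # one column index per row, shape (3,)
col_idx = x.argmax(dim=0)             # one row index per column, shape (4,)
kept = x.argmax(dim=1, keepdim=True)  # reduced dim kept with size 1, shape (3, 1)
values, indices = x.max(dim=1)        # row maxima and their column indices

print(row_idx)     # tensor([2, 2, 2])
print(col_idx)     # tensor([2, 0, 0, 0])
print(kept.shape)  # torch.Size([3, 1])
print(indices)     # tensor([2, 2, 2])
```

A common use of the same call is turning a `(batch, num_classes)` score tensor into predicted class labels, one per sample, with `scores.argmax(dim=1)`.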