torch.maxx
时间: 2023-12-02 13:37:34 浏览: 42
浅谈pytorch torch.backends.cudnn设置作用
5星 · 资源好评率100%
torch.maxx 是 PyTorch 中的一个函数,用于返回输入张量每行的最大值,并可以指定沿着哪一个维度进行计算。具体用法如下:
```python
torch.maxx(input, dim=None, keepdim=False, out=None) -> Tensor
```
其中,input 表示输入的张量;dim 表示进行计算的维度;keepdim 表示是否保留维度;out 表示输出的张量。
例如,对于一个大小为 (3, 4) 的张量,我们可以沿着第 1 个维度计算每行的最大值,代码如下:
```python
import torch
x = torch.tensor([[1, 5, 3, 4],
[2, 6, 8, 0],
[7, 9, 4, 2]])
max_values, _ = torch.max(x, dim=1)
print(max_values) # tensor([5, 8, 9])
```
阅读全文