python中类torch.mean
时间: 2023-07-03 14:09:08 浏览: 150
基于神经网络CNN&ResNet 的图像分类代码、python实现、pytorch框架
`torch.mean` 是 PyTorch 库中的一个函数,用于计算张量的平均值。如果是一个多维张量,可以通过指定 `dim` 参数来沿某些维度计算平均值。
具体来说,`torch.mean(input, dim=None, keepdim=False)` 的函数原型如下:
```
torch.mean(input, dim=None, keepdim=False) -> Tensor
```
其中:
- `input`:需要计算平均值的输入张量。
- `dim`:指定需要沿哪一些维度计算平均值,可以是一个整数或者一个元组。默认为 `None`,表示对所有元素求平均值。
- `keepdim`:指定是否保留计算平均值的维度。默认为 `False`,表示不保留。
例如,对于一个 2x3 的张量,可以通过以下代码计算其所有元素的平均值:
```python
import torch
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mean = torch.mean(x)
print(mean) # 输出为 tensor(3.5000)
```
如果要沿着第 1 维计算平均值,可以将 `dim` 参数设置为 1:
```python
import torch
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mean = torch.mean(x, dim=1)
print(mean) # 输出为 tensor([2., 5.])
```
这个例子中,`mean` 是一个 1 维张量,其中每个元素是原张量在第 1 维上的平均值。
阅读全文