torch.mean如何使用
时间: 2023-12-29 12:49:20 浏览: 102
torch.mean()
torch.mean函数用于计算张量的平均值,其语法如下:
```
torch.mean(input, dim=None, keepdim=False, out=None)
```
参数说明:
- input:输入的张量
- dim:指定计算平均值的维度。如果不指定,则计算所有元素的平均值。
- keepdim:是否保持输出张量的维度和输入张量一致。默认为False。
- out:输出张量,可选参数。
示例代码:
```
import torch
# 创建一个2x3的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
# 计算所有元素的平均值
mean = torch.mean(x)
print(mean)
# 沿着行方向计算平均值
mean_row = torch.mean(x, dim=1)
print(mean_row)
# 沿着列方向计算平均值
mean_col = torch.mean(x, dim=0)
print(mean_col)
```
输出结果为:
```
tensor(3.5000)
tensor([2., 5.])
tensor([2.5000, 3.5000, 4.5000])
```
阅读全文