torch mean
时间: 2023-10-21 07:24:58 浏览: 140
torch.mean()
`torch.mean` 是 PyTorch 中的一个函数,用于计算给定张量在指定维度上的平均值。它的语法如下:
```python
torch.mean(input, dim=None, keepdim=False, out=None)
```
参数说明:
- `input`:输入张量。
- `dim`:要沿着哪个维度计算平均值。如果不指定,则计算整个张量的平均值。
- `keepdim`:是否保持输出张量的维度和输入张量一致,默认为 False。
- `out`:输出张量,可选参数。
示例使用:
```python
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 计算整个张量的平均值
mean_all = torch.mean(x)
print(mean_all) # 输出: tensor(3.5000)
# 沿着第 0 维计算平均值
mean_dim0 = torch.mean(x, dim=0)
print(mean_dim0) # 输出: tensor([2.5000, 3.5000, 4.5000])
# 沿着第 1 维计算平均值,并保持维度
mean_dim1_keepdim = torch.mean(x, dim=1, keepdim=True)
print(mean_dim1_keepdim) # 输出: tensor([[2.],
# [5.]])
```
以上是使用 `torch.mean` 函数计算平均值的示例。希望能对你有所帮助!如果有更多问题,请随时提问。
阅读全文