torch.mean()函数详解
时间: 2024-06-08 16:06:34 浏览: 183
torch.mean()
torch.mean()函数是PyTorch库中的一个函数,用于计算给定张量的平均值。它可以接受一个张量作为输入,并返回该张量的元素的平均值。
该函数的语法如下:
```
torch.mean(input, dim, keepdim=False, *, out=None) -> Tensor
```
参数说明:
- input:输入的张量。
- dim:指定在哪个维度上计算平均值,可以是一个整数或一个元组。当`dim`是一个整数时,表示在指定维度上进行计算;当`dim`是一个元组时,表示在多个维度上进行计算。
- keepdim:指定是否保持输出张量的维度和输入张量一致,默认为False。
- out:可选参数,用于指定输出张量。
示例:
```python
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
mean = torch.mean(x)
print(mean) # 输出5.0
mean_row = torch.mean(x, dim=0)
print(mean_row) # 输出tensor([2.5, 3.5, 4.5])
mean_col = torch.mean(x, dim=1)
print(mean_col) # 输出tensor([2., 5.])
```
上述示例中,我们首先定义了一个2行3列的张量`x`,然后分别使用`torch.mean()`函数计算了整个张量、每一行和每一列的平均值。可以看到,对于整个张量来说,所有元素的平均值为5.0;对于每一行来说,分别计算了每一行中元素的平均值,返回一个包含每一列平均值的张量;对于每一列来说,同样计算了每一列中元素的平均值,返回一个包含每一行平均值的张量。
希望这个解答对你有所帮助!如有更多问题,请随时提问!
阅读全文