torch.normal()
时间: 2024-05-09 10:14:30 浏览: 153
torch.mean()
`torch.normal()`是PyTorch中的一个函数,用于从正态分布(也称为高斯分布)中采样随机数。它的语法如下:
```
torch.normal(mean, std, out=None)
```
其中,`mean`是正态分布的均值,`std`是正态分布的标准差。可以将`mean`和`std`设置为标量或张量。`out`是可选的输出张量,用于存储采样结果。
例如,以下代码将从均值为0、标准差为1的正态分布中采样一个大小为`(3,2)`的张量:
```
import torch
mean = 0
std = 1
size = (3,2)
x = torch.normal(mean, std, size)
print(x)
```
输出:
```
tensor([[-0.3435, -0.3662],
[-1.3872, 0.0586],
[ 0.5145, -0.9186]])
```
阅读全文