torch.distributions.Normal
时间: 2023-07-18 14:15:33 浏览: 260
PyTorch的torch.cat用法
5星 · 资源好评率100%
torch.distributions.Normal 是 PyTorch 中的一个概率分布类,用于表示正态分布。它可以用来生成符合正态分布的随机数,也可以计算正态分布的概率密度函数值、累积分布函数值等等。
在 PyTorch 中,可以通过创建一个 Normal 对象来表示一个正态分布。创建 Normal 对象时需要指定均值和标准差,例如:
```
import torch
from torch.distributions.normal import Normal
mu = torch.tensor([0.0])
sigma = torch.tensor([1.0])
normal = Normal(mu, sigma)
```
这个例子中,我们定义了一个均值为 0,标准差为 1 的正态分布。我们可以使用 `sample()` 方法来生成一个符合该分布的随机数:
```
sample = normal.sample()
```
我们也可以计算该分布的概率密度函数值:
```
pdf = normal.log_prob(sample)
```
这里的 `pdf` 是一个张量,其形状与 `sample` 相同,每个元素表示该随机数在该正态分布下的概率密度函数值。
阅读全文