from torch.distributions import Normal
时间: 2024-10-07 13:02:09 浏览: 45
`from torch.distributions import Normal` 这一行代码导入了PyTorch中的`Normal`类,它属于`torch.distributions`包,这是TensorFlow Probability (TFP)风格的概率分布模块。`Normal`是一个正态分布(也称为高斯分布),是机器学习中最常用的连续概率分布之一。在深度学习中,常用于生成模型(如变分自编码器)、模型的先验分布,以及计算损失函数中的似然度。
使用`Normal`的主要操作包括:
1. 创建分布实例:通过指定均值(mean)和标准差(std)来创建一个`Normal`对象,例如:
```python
normal_dist = Normal(loc=0., scale=1.)
```
这里的`loc`代表均值,`scale`代表标准差。
2. 样本生成:可以从中抽样出单个样本或整个批次的样本:
```python
sample = normal_dist.sample(sample_shape=torch.Size([100])) # 生成100个样本
```
3. 计算概率密度:对于给定的值,可以计算其对应的概率密度:
```python
value = torch.tensor([1.])
pdf_value = normal_dist.pdf(value) # 返回1.的PDF值
```
4. 定义概率函数:对于给定的样本集,返回所有样本的概率总和(在无约束条件下为1):
```python
log_prob = normal_dist.log_prob(value) # 对应于计算对数概率密度
```
阅读全文