torch.distributions
时间: 2023-10-31 11:58:15 浏览: 49
torch.distributions 是 PyTorch 中用于概率分布建模和随机采样的模块。它提供了各种常见的概率分布,如正态分布、均匀分布、贝塔分布等,并支持对这些分布进行采样、计算概率密度函数(PDF)、累积分布函数(CDF)等操作。
这个模块非常有用,可以用于构建生成模型、变分自编码器、深度强化学习等多个领域的应用。通过使用 torch.distributions,你可以方便地进行概率建模,并使用 PyTorch 的自动求导功能进行模型训练。
你还有其他关于 torch.distributions 的问题吗?
相关问题
torch.distributions.Normal
torch.distributions.Normal 是 PyTorch 中的一个概率分布类,用于表示正态分布。它可以用来生成符合正态分布的随机数,也可以计算正态分布的概率密度函数值、累积分布函数值等等。
在 PyTorch 中,可以通过创建一个 Normal 对象来表示一个正态分布。创建 Normal 对象时需要指定均值和标准差,例如:
```
import torch
from torch.distributions.normal import Normal
mu = torch.tensor([0.0])
sigma = torch.tensor([1.0])
normal = Normal(mu, sigma)
```
这个例子中,我们定义了一个均值为 0,标准差为 1 的正态分布。我们可以使用 `sample()` 方法来生成一个符合该分布的随机数:
```
sample = normal.sample()
```
我们也可以计算该分布的概率密度函数值:
```
pdf = normal.log_prob(sample)
```
这里的 `pdf` 是一个张量,其形状与 `sample` 相同,每个元素表示该随机数在该正态分布下的概率密度函数值。
torch.distributions.Categorical
torch.distributions.categorical是PyTorch中的一个概率分布模块,用于生成分类分布。
该模块包含了一个Categorical类,可以用来创建分类分布对象。分类分布用于生成从一组离散概率分布中选择的随机样本。Categorical类的构造函数需要一个1-D张量probs,其中每个元素都是该类别被选中的概率。可以使用这个类的sample()方法生成从这个分布中采样的值,或者使用log_prob()方法计算一个或多个给定值的对数概率。
例如,以下代码创建了一个包含3个类别的分类分布,其中第一个类别的概率为0.5,第二个和第三个类别的概率分别为0.25:
```python
import torch
probs = torch.tensor([0.5, 0.25, 0.25])
categorical_dist = torch.distributions.categorical.Categorical(probs=probs)
```
可以使用sample()方法从分类分布中生成一个样本:
```python
sample = categorical_dist.sample()
```
可以使用log_prob()方法计算样本的对数概率:
```python
log_prob = categorical_dist.log_prob(sample)
```
这里的样本是一个从分类分布中随机生成的整数,它的值介于0和2之间,对应于分布中的三个类别。