torch.distributions.Categorical
时间: 2024-04-07 15:30:13 浏览: 175
torch.distributions.categorical是PyTorch中的一个概率分布模块,用于生成分类分布。
该模块包含了一个Categorical类,可以用来创建分类分布对象。分类分布用于生成从一组离散概率分布中选择的随机样本。Categorical类的构造函数需要一个1-D张量probs,其中每个元素都是该类别被选中的概率。可以使用这个类的sample()方法生成从这个分布中采样的值,或者使用log_prob()方法计算一个或多个给定值的对数概率。
例如,以下代码创建了一个包含3个类别的分类分布,其中第一个类别的概率为0.5,第二个和第三个类别的概率分别为0.25:
```python
import torch
probs = torch.tensor([0.5, 0.25, 0.25])
categorical_dist = torch.distributions.categorical.Categorical(probs=probs)
```
可以使用sample()方法从分类分布中生成一个样本:
```python
sample = categorical_dist.sample()
```
可以使用log_prob()方法计算样本的对数概率:
```python
log_prob = categorical_dist.log_prob(sample)
```
这里的样本是一个从分类分布中随机生成的整数,它的值介于0和2之间,对应于分布中的三个类别。
相关问题
torch.distributions.categorical
torch.distributions.categorical是PyTorch中的一个概率分布模块,用于生成分类分布。
该模块包含了一个Categorical类,可以用来创建分类分布对象。分类分布用于生成从一组离散概率分布中选择的随机样本。Categorical类的构造函数需要一个1-D张量probs,其中每个元素都是该类别被选中的概率。可以使用这个类的sample()方法生成从这个分布中采样的值,或者使用log_prob()方法计算一个或多个给定值的对数概率。
例如,以下代码创建了一个包含3个类别的分类分布,其中第一个类别的概率为0.5,第二个和第三个类别的概率分别为0.25:
```python
import torch
probs = torch.tensor([0.5, 0.25, 0.25])
categorical_dist = torch.distributions.categorical.Categorical(probs=probs)
```
可以使用sample()方法从分类分布中生成一个样本:
```python
sample = categorical_dist.sample()
```
可以使用log_prob()方法计算样本的对数概率:
```python
log_prob = categorical_dist.log_prob(sample)
```
这里的样本是一个从分类分布中随机生成的整数,它的值介于0和2之间,对应于分布中的三个类别。
class torch.distributions.Categorical(probs)
`torch.distributions.Categorical(probs)`是PyTorch中的一个概率分布类,用于表示具有离散取值的随机变量。它的参数`probs`是一个概率张量,表示每个离散取值的概率。
该类可以用于模拟离散随机变量的采样、计算随机变量的概率密度函数(probability density function, PDF)、计算随机变量的对数概率密度函数(log probability density function, log PDF)等操作。
例如,可以使用`Categorical`类来模拟抛硬币的结果,其中正面出现的概率为0.6,反面出现的概率为0.4:
```python
import torch
from torch.distributions import Categorical
probs = torch.tensor([0.6, 0.4])
dist = Categorical(probs)
# 采样一个随机变量
sample = dist.sample()
print(sample) # 输出0或1,表示正面或反面
# 计算随机变量的概率密度函数
pdf = dist.log_prob(sample)
print(pdf) # 输出一个张量,表示正面或反面的对数概率
```
以上代码中,`sample`表示一次抛硬币的结果,它的值为0或1,对应正面或反面。`pdf`表示抛硬币结果对应的对数概率。
阅读全文