torch.bernoulli
Posted: 2024-01-11 20:02:56
torch.bernoulli is a PyTorch function that draws samples from Bernoulli distributions. It takes a tensor of probabilities as input, where every element must lie in [0, 1]. It returns a tensor of the same shape and dtype in which each element is independently set to 1 ("success") with the probability given by the corresponding input element, and to 0 otherwise.
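Because the probabilities are applied element by element, one call can sample from many different Bernoulli distributions at once. The minimal sketch below illustrates this with an arbitrary 2×3 probability tensor chosen only for demonstration. It also checks two edge cases that are deterministic: a probability of 0 always yields 0 and a probability of 1 always yields 1.

```python
import torch

# Each element is the success probability for the matching output element.
probs = torch.tensor([[0.0, 1.0, 0.5],
                      [1.0, 0.0, 0.25]])

samples = torch.bernoulli(probs)

# The output has the same shape and dtype as the input.
print(samples.shape)  # torch.Size([2, 3])
print(samples.dtype)  # torch.float32

# p = 0 and p = 1 are deterministic; p = 0.5 and p = 0.25 are random.
print(samples[0, 0].item(), samples[0, 1].item())  # 0.0 1.0
```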
For example, to generate a tensor of 10 binary values, each with a success probability of 0.7, we can use the following code:
```
import torch
probs = torch.tensor([0.7]) # probability of success
samples = torch.bernoulli(probs.repeat(10)) # generate 10 samples
print(samples)
```
Because the samples are random, the output changes from run to run. One possible result is shown below. The values print as floats (`1.`, `0.`) because the output keeps the floating-point dtype of the input:
```
tensor([1., 1., 1., 0., 1., 0., 1., 1., 1., 1.])
```
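Since each run gives different values, it can help to seed a dedicated generator and to check that the fraction of ones approaches the requested probability. The sketch below does both, using `torch.Generator` with the `generator=` keyword. It also shows `Tensor.bernoulli_`, an in-place variant that accepts a scalar probability, as an alternative to building a tensor of repeated probabilities. The sample size of 100,000 and the seed 0 are arbitrary choices for illustration.

```python
import torch

n, p = 100_000, 0.7

# A dedicated, seeded generator makes the draws reproducible.
gen = torch.Generator().manual_seed(0)

# Option 1: a tensor of probabilities, as in the example above.
a = torch.bernoulli(torch.full((n,), p), generator=gen)

# Option 2: in-place sampling with a scalar probability; the uninitialized
# contents of the torch.empty tensor are overwritten with 0s and 1s.
b = torch.empty(n).bernoulli_(p, generator=gen)

# Reseeding the generator reproduces exactly the same draws for option 1.
gen.manual_seed(0)
a_again = torch.bernoulli(torch.full((n,), p), generator=gen)
print(torch.equal(a, a_again))  # True

# The fraction of ones approaches p as n grows.
print(a.mean().item(), b.mean().item())  # both close to 0.7
```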