torch randint
时间: 2023-11-01 19:22:12 浏览: 92
`torch.randint()`是PyTorch库中的一个函数,用于生成随机整数张量。它的使用方法如下:
```python
torch.randint(low, high, size, dtype=None, layout=torch.strided, device=None, requires_grad=False, generator=None)
```
其中:
- `low`表示生成的随机整数的下界(包含在范围内)。
- `high`表示生成的随机整数的上界(不包含在范围内)。
- `size`表示生成的张量的形状。
- `dtype`表示生成的张量的数据类型,默认为None,即自动推断。
- `layout`表示生成的张量的布局,默认为torch.strided。
- `device`表示生成的张量所在的设备,默认为None,即使用当前设备。
- `requires_grad`表示生成的张量是否需要梯度,默认为False,即不需要梯度。
- `generator`表示用于生成随机数的生成器,默认为None,即使用全局生成器。
例如,我们可以使用如下代码生成一个形状为`(2, 3)`、取值范围在0到9之间(不包含9)的随机整数张量:
```python
import torch
rand_int = torch.randint(0, 9, (2, 3))
print(rand_int)
```
这将会输出类似以下的结果:
```
tensor([[8, 2, 5],
[6, 7, 1]])
```
希望这能解答你的问题!如果还有其他问题,请继续提问。
阅读全文