torch.randint()怎么用
时间: 2024-05-03 15:21:10 浏览: 12
torch.randint()函数用于生成指定形状和范围的随机整数张量。以下是torch.randint()函数的语法和用法示例:
语法:
```python
torch.randint(high, size, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
torch.randint(low, high, size, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
```
参数说明:
- high (int):张量中随机数的上限值。
- low (int):张量中随机数的下限值。
- size (tuple):生成的随机张量的形状。
- dtype (torch.dtype, optional):生成的随机张量类型。默认为torch.int64。
- layout (torch.layout, optional):生成的随机张量布局。默认为torch.strided。
- device (torch.device, optional):生成的随机张量的设备。默认为当前设备。
- requires_grad (bool, optional):是否需要计算梯度。默认为False。
用法示例:
```python
import torch
# 生成形状为(3, 4)的随机整数张量,范围为[0, 10)
x = torch.randint(0, 10, size=(3, 4))
print(x)
# 生成形状为(2, 3, 4)的随机整数张量,范围为[0, 5)
y = torch.randint(5, size=(2, 3, 4))
print(y)
```
输出:
```
tensor([[2, 9, 1, 5],
[8, 4, 1, 4],
[4, 4, 9, 0]])
tensor([[[4, 4, 4, 4],
[0, 4, 4, 0],
[4, 4, 3, 1]],
[[3, 3, 2, 2],
[2, 3, 3, 2],
[3, 2, 0, 3]]])
```