介绍一下torch.randint及其参数
时间: 2023-12-30 15:05:09 浏览: 148
Pythorch中torch.nn.LSTM()参数详解
5星 · 资源好评率100%
torch.randint是PyTorch库中的一个函数,用于生成随机整数张量。它的基本语法为:
`torch.randint(low=0, high, size, dtype=None, layout=torch.strided, device=None, requires_grad=False)`
其中,参数的含义如下:
- low(int):随机整数的最小值。
- high(int):随机整数的最大值(不包括此值)。
- size(tuple):生成的张量的形状。
- dtype(torch.dtype,可选):生成的张量的数据类型。默认为torch.int64。
- layout(torch.layout,可选):生成的张量的布局。默认为torch.strided。
- device(torch.device,可选):生成的张量所在的设备。默认为None,表示使用默认设备。
- requires_grad(bool,可选):是否需要计算梯度。默认为False。
例如,以下代码生成了一个10行5列的随机整数张量,取值范围为0到9:
```
import torch
x = torch.randint(low=0, high=10, size=(10, 5))
print(x)
```
输出如下:
```
tensor([[4, 1, 5, 8, 6],
[1, 6, 6, 3, 5],
[0, 2, 6, 9, 2],
[2, 6, 6, 8, 8],
[7, 7, 2, 3, 9],
[1, 0, 8, 2, 9],
[1, 2, 9, 2, 8],
[4, 6, 3, 0, 2],
[4, 6, 3, 1, 4],
[1, 6, 4, 3, 4]])
```
阅读全文