torch.randn()函数的用法
时间: 2023-10-14 08:00:13 浏览: 140
`torch.randn()` 函数是 PyTorch 中的一个函数,用于生成服从标准正态分布(均值为0,方差为1)的随机数。它的语法如下:
```python
torch.randn(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
```
参数说明:
- `*size`:可变参数,用于指定返回张量的形状。例如,`torch.randn(2, 3)` 会生成一个形状为 (2, 3) 的张量。
- `out`:可选参数,用于指定输出张量的位置。
- `dtype`:可选参数,用于指定输出张量的数据类型。默认为 `None`,即使用默认的数据类型。
- `layout`:可选参数,用于指定输出张量的布局,默认为 `torch.strided`。
- `device`:可选参数,用于指定输出张量所在的设备。默认为 `None`,即使用默认的设备(通常是 CPU)。
- `requires_grad`:可选参数,用于指定输出张量是否需要梯度计算。默认为 `False`。
示例代码:
```python
import torch
# 生成一个形状为 (2, 3) 的随机张量
x = torch.randn(2, 3)
print(x)
```
以上代码中,`torch.randn(2, 3)` 生成了一个形状为 (2, 3) 的随机张量,并将其赋值给变量 `x`。最后,使用 `print()` 函数打印出 `x` 的值。
注意:`torch.randn()` 生成的随机数是从标准正态分布中抽取的,即均值为0,方差为1。如果你需要生成服从其他分布的随机数,可以使用其他函数,例如 `torch.rand()`(生成在[0, 1)范围内的均匀分布随机数)或 `torch.normal()`(生成服从指定均值和方差的正态分布随机数)。
阅读全文