X = torch.arange(16).view(2, 8) mask = (torch.rand(X.shape) < keep_prob).float() 详细注释
时间: 2024-05-13 11:20:26 浏览: 110
这是一段 PyTorch 代码。首先,使用 `torch.arange(16)` 创建一个长度为 16 的张量,再使用 `view(2, 8)` 将张量转换为形状为 2 行 8 列的张量。这个张量 X 的值为:
```
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
```
接下来,使用 `torch.rand(X.shape)` 创建一个与 X 形状相同的随机张量,其中每个元素都是从均匀分布 `[0, 1)` 中随机采样得到的。然后,将这个随机张量与 `keep_prob` 比较,生成一个布尔型张量,其中每个元素都有 `keep_prob` 的概率为 1,有 `(1 - keep_prob)` 的概率为 0。最后,将这个布尔型张量转换为浮点型张量,其中 1 对应着保留该元素,0 对应着丢弃该元素。这个浮点型张量称为掩码(mask),可以用于随机屏蔽部分输入,实现 dropout 操作。
阅读全文