torch.flatten()
时间: 2023-10-31 20:57:17 浏览: 99
0460-极智开发-解读torch.flatten算子
5星 · 资源好评率100%
`torch.flatten()` 是 PyTorch 中用于将多维张量展平为一维张量的函数。它可以用于在神经网络中处理输入数据的展平操作,使得输入数据适合全连接层等需要一维输入的层。
下面是 `torch.flatten()` 的基本用法示例:
```python
import torch
# 创建一个多维张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将多维张量展平为一维张量
flatten_x = torch.flatten(x)
# 输出结果
print(flatten_x)
```
在这个示例中,我们首先创建了一个多维张量 `x`,其形状为 `(2, 3)`。然后,我们通过调用 `torch.flatten(x)` 将多维张量 `x` 展平为一维张量 `flatten_x`。最后,我们打印输出结果。
`torch.flatten()` 的默认行为是按行展平多维张量。也就是说,它会按照从左到右、从上到下的顺序将元素排列为一维张量。如果需要按列展平,可以指定参数 `start_dim` 和 `end_dim`,以控制展平的维度范围。
希望这个解答对你有帮助!
阅读全文