nn.flatten
时间: 2023-10-23 10:35:16 浏览: 93
对Pytorch中nn.ModuleList 和 nn.Sequential详解
nn.flatten 是一个在神经网络中常用的操作,用于将多维输入数据展平为一维。在深度学习中,我们经常需要将输入数据转换为一维向量,以便进行全连接层或其他一维操作。 nn.flatten 可以看作是一个将输入数据进行展开的操作,将多维数组铺平为一维向量。
在 PyTorch 中,nn.flatten 可以通过 nn.Flatten() 模块来实现。例如:
```python
import torch
import torch.nn as nn
flatten = nn.Flatten()
input_data = torch.randn(32, 3, 64, 64) # 输入数据为32个大小为 [3, 64, 64] 的样本
output_data = flatten(input_data) # 将输入数据展平为一维向量
print(output_data.size()) # 输出展平后的大小: torch.Size([32, 12288])
```
上述代码中,我们首先定义了 nn.Flatten() 模块,然后使用该模块对输入数据进行展平操作。展平后的输出数据大小为 [32, 12288],其中 32 是样本数量,12288 是展平后的一维向量长度。
阅读全文