匹配输入数据的维度与模型的期望维度的代码
时间: 2024-03-01 09:54:53 浏览: 23
以下是一个简单的示例代码,展示了如何匹配输入数据的维度与模型的期望维度:
```python
import torch
import torch.nn as nn
# 定义一个简单的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
# 定义输入数据
x = torch.randn(3, 10)
# 创建模型实例
model = MyModel()
# 检查模型的输入维度
print(model.fc1.in_features) # 输出为 10
# 检查输入数据的维度
print(x.size()) # 输出为 torch.Size([3, 10])
# 将输入数据调整为模型期望的维度
x = x.view(-1, model.fc1.in_features)
# 检查调整后的输入数据的维度
print(x.size()) # 输出为 torch.Size([3, 10])
# 将输入数据输入模型进行计算
output = model(x)
# 检查输出数据的维度
print(output.size()) # 输出为 torch.Size([3, 2])
```
在这个示例代码中,我们定义了一个简单的模型 `MyModel`,它包含两个全连接层,输入维度为 10,输出维度为 2。然后,我们定义了一个输入数据 `x`,它的维度为 (3, 10)。我们通过查看模型的输入维度和输入数据的维度来检查它们是否匹配。接下来,我们使用 `view()` 方法将输入数据调整为模型期望的维度,然后将其输入模型进行计算。最后,我们检查输出数据的维度是否与期望相同。