做2输入3输出的cnn非回归预测的代码
时间: 2024-02-24 13:56:04 浏览: 153
以下是一个简单的2输入3输出的CNN非回归预测的代码示例:
```python
import torch
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(2, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc = nn.Linear(32 * 7 * 7, 3)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 32 * 7 * 7)
x = self.fc(x)
return x
# 定义模型和优化器
model = CNN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 准备数据
inputs = torch.randn(10, 2, 28, 28)
labels = torch.randn(10, 3)
# 训练模型
for epoch in range(100):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 100, loss.item()))
# 测试模型
with torch.no_grad():
test_inputs = torch.randn(1, 2, 28, 28)
test_outputs = model(test_inputs)
print('Test Output:', test_outputs)
```
这个代码示例定义了一个简单的CNN模型,该模型接受两个28x28的输入图像,并输出三个预测值。训练过程中使用均方误差作为损失函数,Adam优化器进行优化。最后,使用一个随机的输入图像进行测试,输出该模型的预测结果。
阅读全文