cnn代码,使用pytorch,将45x235的低分辨率输入映射为150x675的高分辨率输出,训练数据为700个,测试数据为300个,训练数据是存放在编号为1-700的excel表格一列中,测试数据存放在编号为701-1000的excel表格一列中。写出一个读取excel并且进行训练的数据,并且将训练误差和测试误差plot出来
时间: 2024-04-08 16:30:57 浏览: 74
以下是一个使用PyTorch进行CNN训练的示例代码,包括读取Excel数据和绘制训练误差和测试误差的图表:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
# 定义CNN模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.fc = nn.Linear(32 * 45 * 235, 150 * 675)
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 读取训练数据
train_data = pd.read_excel('data.xlsx', sheet_name='Sheet1', usecols=[0], nrows=700)
train_data = torch.tensor(train_data.values.reshape(-1, 1, 45, 235), dtype=torch.float32)
# 读取测试数据
test_data = pd.read_excel('data.xlsx', sheet_name='Sheet1', usecols=[0], skiprows=700)
test_data = torch.tensor(test_data.values.reshape(-1, 1, 45, 235), dtype=torch.float32)
# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
num_epochs = 10
train_losses = []
test_losses = []
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
for i, data in enumerate(train_data):
optimizer.zero_grad()
outputs = model(data.unsqueeze(0))
loss = criterion(outputs, data)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_losses.append(train_loss / len(train_data))
model.eval()
test_loss = 0.0
with torch.no_grad():
for i, data in enumerate(test_data):
outputs = model(data.unsqueeze(0))
loss = criterion(outputs, data)
test_loss += loss.item()
test_losses.append(test_loss / len(test_data))
# 绘制训练误差和测试误差的图表
plt.plot(range(1, num_epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, num_epochs+1), test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
```
请确保你已经安装了相应的库(如torch、pandas和matplotlib)并将训练数据保存在名为"data.xlsx"的Excel文件中的第一个工作表中。如果你的数据格式有所不同,请根据实际情况进行修改。运行上述代码后,将会得到训练误差和测试误差的图表。
阅读全文