CNN手写数字识别csv训练集读取保存加载pytorch模型测试csv数据集和单张图片
时间: 2024-04-17 15:24:20 浏览: 182
你可以使用以下代码来读取、保存和加载CSV格式训练集,并使用PyTorch模型对CSV数据集和单张图片进行测试:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
# 创建自定义数据集类
class CustomDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
image = self.data.iloc[idx, 1:].values.astype('float32')
label = self.data.iloc[idx, 0]
return torch.tensor(image), torch.tensor(label)
# 创建CNN模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2)
self.fc = nn.Linear(16*14*14, 10)
def forward(self, x):
x = x.view(-1, 1, 28, 28) # 转换为图片的形状
x = self.conv1(x)
x = self.relu(x)
x = self.maxpool(x)
x = x.view(-1, 16*14*14) # 展平特征图
x = self.fc(x)
return x
# 读取训练集并创建数据加载器
train_dataset = CustomDataset('train.csv')
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 训练并保存模型
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
torch.save(model.state_dict(), 'model.pth') # 保存模型参数
# 加载模型并测试CSV数据集
model = CNN()
model.load_state_dict(torch.load('model.pth')) # 加载模型参数
test_dataset = CustomDataset('test.csv')
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
model.eval()
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
print(predicted) # 打印预测结果
# 测试单张图片
from PIL import Image
image_path = 'image.jpg'
image = Image.open(image_path).convert('L') # 转换为灰度图
image = image.resize((28, 28)) # 调整大小为28x28
image = torch.tensor(image, dtype=torch.float32)
image = image.reshape(1, 1, 28, 28) # 转换为图片的形状
model.eval()
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output.data, 1)
print(predicted) # 打印预测结果
```
请确保已安装所需的库,如torch、torchvision、pandas和PIL。你需要将训练集和测试集的CSV文件命名为'train.csv'和'test.csv',并放在与代码文件相同的目录下。同时,将要测试的单张图片命名为'image.jpg',并将其放在相同的目录下。
在训练时,模型会保存为'model.pth'文件。你可以根据自己的需求修改模型结构、训练参数和优化器等。
阅读全文
相关推荐
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![py](https://img-home.csdnimg.cn/images/20250102104920.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)