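CNN handwritten digit recognition: read a CSV training set, save the trained PyTorch model, then load the model from another file to test a CSV dataset and a single image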
Date: 2024-01-14 14:04:40
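The steps below train a CNN for handwritten digit recognition, save the model, load it again, and use it to test a CSV dataset and a single image.
1. Train and save the model:
- First, import the required libraries: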
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
```
- Define a simple CNN model:
```python
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 1 input channel (grayscale) -> 16 feature maps; padding=1 keeps 28x28
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # 2x2 pooling halves the spatial size to 14x14
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(14*14*16, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.maxpool(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, 14*14*16)
        out = self.fc(out)
        return out
```
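The `14*14*16` input size of the fully connected layer follows from the layer settings above. As a quick check, the sketch below pushes a dummy batch through the same convolution and pooling settings and prints the shape after each step. The batch size of 4 is an arbitrary assumption for illustration.

```python
import torch
import torch.nn as nn

# Same layer settings as the CNN above, applied step by step to a dummy batch
conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.zeros(4, 1, 28, 28)       # hypothetical batch of 4 grayscale 28x28 images
after_conv = conv1(x)               # padding=1 keeps 28x28; channels become 16
after_pool = maxpool(after_conv)    # 2x2 pooling halves each spatial side to 14
flat = after_pool.view(after_pool.size(0), -1)

print(after_conv.shape)  # torch.Size([4, 16, 28, 28])
print(after_pool.shape)  # torch.Size([4, 16, 14, 14])
print(flat.shape)        # torch.Size([4, 3136])
```

If you change the kernel size, padding, or pooling, rerunning a check like this tells you the new input size for `nn.Linear`.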
- Load the MNIST dataset and preprocess it:
```python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = MNIST(root='./', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./', train=False, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
```
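The code above downloads MNIST through torchvision. If your training set is a CSV file instead, a small custom `Dataset` can feed the same training loop. The sketch below is one possible approach, not the only one: it assumes the first column is the label and the remaining 784 columns are pixel values from 0 to 255. The class name `CsvDigits`, the column layout, and the in-memory demo data are all illustrative assumptions.

```python
import io
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class CsvDigits(Dataset):
    """Hypothetical dataset: column 0 is the label, the next 784 columns are pixels (0-255)."""
    def __init__(self, csv_source):
        df = pd.read_csv(csv_source)
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        pixels = df.iloc[:, 1:].values.astype(np.float32) / 255.0  # scale to [0, 1], like ToTensor
        pixels = (pixels - 0.5) / 0.5                              # same effect as Normalize((0.5,), (0.5,))
        self.images = torch.from_numpy(pixels).reshape(-1, 1, 28, 28)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Tiny in-memory CSV standing in for a real file path (the layout is an assumption)
header = "label," + ",".join(f"pixel{i}" for i in range(784))
rows = ["3," + ",".join(["0"] * 784), "7," + ",".join(["255"] * 784)]
demo_csv = io.StringIO("\n".join([header] + rows))

csv_dataset = CsvDigits(demo_csv)
csv_loader = DataLoader(csv_dataset, batch_size=2, shuffle=False)
images, labels = next(iter(csv_loader))
print(images.shape, labels.tolist())
```

With a real file you would pass its path to `CsvDigits` and use the resulting loader in place of `train_loader`.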
- Initialize the model, loss function, and optimizer:
```python
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
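Note that the model returns raw scores (logits) with no softmax layer. That is intentional, because `nn.CrossEntropyLoss` applies log-softmax internally. The small check below, using made-up logits and targets, illustrates the equivalence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for 2 samples and 3 classes, with integer class targets
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

loss_ce = nn.CrossEntropyLoss()(logits, targets)
# Manual equivalent: log-softmax followed by negative log-likelihood
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(torch.allclose(loss_ce, loss_manual))
```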
- Train the model and save it:
```python
num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()             # clear gradients from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

# Save only the learned weights (state_dict) once training has finished
torch.save(model.state_dict(), 'model.pt')
```
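The training loop only prints the loss, so it can help to measure accuracy on held-out data before relying on the saved model. The helper below is a minimal sketch of such a check. The name `evaluate` is an assumption, and the stand-in model and random data exist only so the snippet runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader):
    """Return the classification accuracy of `model` over `loader` as a float in [0, 1]."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Stand-in model and random data so the sketch is self-contained
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
demo_data = TensorDataset(torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))
demo_loader = DataLoader(demo_data, batch_size=8)

accuracy = evaluate(demo_model, demo_loader)
print(f"Accuracy: {accuracy:.2%}")
```

In this walkthrough you would call `evaluate(model, test_loader)` with the trained CNN instead of the stand-ins.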
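2. Load the model and use it for prediction:
- Import the required libraries (when this code lives in a separate file, the `CNN` class must also be defined or imported there):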
```python
import torch
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
```
- Load the model and define a prediction function:
```python
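# The preprocessing must match what was used during training
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
model = CNN()  # the architecture must match the saved weights
model.load_state_dict(torch.load('model.pt'))
model.eval()   # switch to inference mode

def predict(image):
    image = transform(image).unsqueeze(0)  # add a batch dimension: (1, 1, 28, 28)
    with torch.no_grad():                  # gradients are not needed for inference
        output = model(image)
    _, predicted = torch.max(output, 1)
    return predicted.item()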
```
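Since `state_dict` stores only the weights, the loading file has to build a model with the same architecture before calling `load_state_dict`. The round trip below sketches this with a tiny stand-in layer and a temporary file (both are assumptions for illustration). `map_location="cpu"` is an optional argument that lets weights saved on a GPU load on a CPU-only machine.

```python
import os
import tempfile
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in for the trained CNN
path = os.path.join(tempfile.mkdtemp(), "demo_model.pt")
torch.save(net.state_dict(), path)

restored = nn.Linear(4, 2)  # must be constructed with the same architecture
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(net(x), restored(x))
print(same)
```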
3. Test a CSV dataset and a single image with the model:
- Test the CSV dataset:
```python
# Assumes every row holds exactly 784 pixel values (0-255) and no label column
df = pd.read_csv('test.csv')
for i in range(len(df)):
    image = Image.fromarray(df.iloc[i].values.reshape(28, 28).astype('uint8'))
    prediction = predict(image)
    print(f"Image {i+1}: Predicted Digit - {prediction}")
```
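Predicting one row at a time through PIL is easy to follow but slow for large files. One alternative is to convert the whole DataFrame to a tensor and run the model in batches, as sketched below. The function name `predict_frame`, the batch size, and the random stand-in data are assumptions. The normalization mirrors `ToTensor` followed by `Normalize((0.5,), (0.5,))`.

```python
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

def predict_frame(model, df, batch_size=256):
    """Predict a digit for every row of a DataFrame with 784 pixel columns (0-255)."""
    pixels = df.to_numpy(dtype=np.float32) / 255.0   # scale to [0, 1]
    pixels = (pixels - 0.5) / 0.5                    # normalize to [-1, 1]
    images = torch.from_numpy(pixels).reshape(-1, 1, 28, 28)
    model.eval()
    predictions = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            predictions.extend(model(batch).argmax(dim=1).tolist())
    return predictions

# Random stand-in data and model so the sketch runs on its own
demo_df = pd.DataFrame(np.random.randint(0, 256, size=(5, 784)))
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
preds = predict_frame(demo_model, demo_df, batch_size=2)
print(preds)
```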
- Test a single image:
```python
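# Convert to grayscale and resize to 28x28, the input size the model expects.
# MNIST digits are light on a dark background; invert the image first if yours is the opposite.
image = Image.open('image.jpg').convert('L').resize((28, 28))
prediction = predict(image)
print(f"Predicted Digit - {prediction}")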
```
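Photos and scans usually show a dark digit on a light background, the opposite of MNIST, which tends to hurt predictions. The sketch below shows one way to handle that with Pillow's `ImageOps.invert` before resizing and normalizing. The helper name `load_digit` and the blank demo image are assumptions; for a real file you would pass `Image.open(...)` instead.

```python
import numpy as np
import torch
from PIL import Image, ImageOps

def load_digit(image, invert=True):
    """Convert a PIL image into a normalized 1x1x28x28 tensor.

    invert=True assumes a dark digit on a light background.
    """
    gray = image.convert("L")
    if invert:
        gray = ImageOps.invert(gray)      # make the digit light on dark, like MNIST
    gray = gray.resize((28, 28))
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    pixels = (pixels - 0.5) / 0.5         # same effect as Normalize((0.5,), (0.5,))
    return torch.from_numpy(pixels).reshape(1, 1, 28, 28)

# Blank white image standing in for a real photo
demo_image = Image.new("RGB", (100, 80), color="white")
tensor = load_digit(demo_image)
print(tensor.shape, tensor.min().item(), tensor.max().item())
```

The returned tensor can be passed straight to `model(...)` inside `torch.no_grad()`, skipping the `transform` step.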
With these steps you can test the trained CNN on both a CSV dataset and a single image. This is only a simple example; modify and extend it to fit your needs.