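CNN handwritten digit recognition in PyTorch: reading a CSV, saving and loading the model, testing on a dataset and a single image
Date: 2024-04-18 17:24:18
You can follow the steps below to build a CNN for handwritten digit recognition: read the data from a CSV file, train the model, save and reload it, then evaluate it on a test set and on a single image.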
1. Import the required libraries:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
```
2. Define the CNN model:
```python
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Stage 1: 1x28x28 -> 32x28x28 -> (pool) 32x14x14
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 32x14x14 -> 64x14x14 -> (pool) 64x7x7
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: 7*7*64 features -> 128 -> 10 digit classes
        self.fc1 = nn.Linear(7*7*64, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 7*7*64)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)  # raw logits; CrossEntropyLoss applies softmax internally
        return x

model = CNN()
```
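The `7*7*64` input size of `fc1` follows from the layer parameters above. As a quick sanity check, the sketch below traces the spatial size through each layer using the standard output-size formula for convolution and pooling without dilation. The helper name `conv_out` is made up for this illustration and is not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                             # 28x28 input image
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)             # pool1 halves to 14
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)             # pool2 halves to 7

flat_features = size * size * 64                      # 64 channels after conv2
print(size, flat_features)                            # 7 3136
```

If you change the kernel size, the padding, or the input resolution, recompute this value and update `fc1` to match.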
3. Read the CSV file and prepare the dataset:
```python
class CustomDataset(Dataset):
    def __init__(self, csv_path):
        # Assumes no header row: column 0 is the label, columns 1..784 are pixel values
        self.data = pd.read_csv(csv_path, header=None)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        label = int(self.data.iloc[index, 0])
        image = self.data.iloc[index, 1:].values.reshape(28, 28).astype(np.uint8)
        image = np.expand_dims(image, axis=0)  # add channel dimension: (1, 28, 28)
        return image, label

csv_path = 'path/to/your/csv/file.csv'
dataset = CustomDataset(csv_path)
```
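`CustomDataset` expects each CSV row to hold a label followed by 784 pixel values, with no header row. To make that layout concrete, here is a minimal standard-library sketch that writes a tiny dummy file in this format and reads it back. The file name, the helper names `write_dummy_csv` and `read_rows`, and the constant gray pixel values are all illustrative assumptions, not real data.

```python
import csv
import os
import tempfile

def write_dummy_csv(path, labels):
    """Write one row per label: the label followed by 784 pixel values (0-255)."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        for label in labels:
            pixels = [(label * 25) % 256] * (28 * 28)  # constant gray level, illustrative only
            writer.writerow([label] + pixels)

def read_rows(path):
    """Read rows back as (label, 28x28 nested list), mirroring CustomDataset.__getitem__."""
    samples = []
    with open(path, newline="") as f:
        for row in csv.reader(f):
            values = [int(v) for v in row]
            label, pixels = values[0], values[1:]
            image = [pixels[r * 28:(r + 1) * 28] for r in range(28)]
            samples.append((label, image))
    return samples

dummy_path = os.path.join(tempfile.mkdtemp(), "dummy_digits.csv")
write_dummy_csv(dummy_path, labels=[3, 7])
samples = read_rows(dummy_path)
print(len(samples), samples[0][0], len(samples[0][1]), len(samples[0][1][0]))  # 2 3 28 28
```

A file like this is handy for smoke-testing the pipeline before you point `csv_path` at the real data. If your CSV does have a header row, drop `header=None` from `pd.read_csv`.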
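4. Split into training and test sets. Splitting the sample indices and wrapping them in `Subset` keeps the samples as lazily loaded `Dataset` objects instead of passing the `Dataset` itself to `train_test_split`:
```python
train_idx, test_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)
train_dataset, test_dataset = torch.utils.data.Subset(dataset, train_idx), torch.utils.data.Subset(dataset, test_idx)
```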
5. Create the data loaders:
```python
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```
6. Define the loss function and optimizer:
```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
7. Train the model:
```python
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images.float())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss is the batch mean, so weight it by the batch size
        train_loss += loss.item() * images.size(0)
    train_loss /= len(train_loader.dataset)  # per-sample average over the epoch
    print(f"Epoch: {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")
```
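The training loop multiplies each batch loss by the batch size before dividing by the dataset size. This matters because the last batch is usually smaller than the others, so a plain average of batch means would give it too much weight. The sketch below shows the difference with made-up numbers; the losses and batch sizes are hypothetical.

```python
# Hypothetical per-batch mean losses for 150 samples with batch_size=64
batch_losses = [0.50, 0.40, 0.10]
batch_sizes = [64, 64, 22]  # the last batch is smaller

total_samples = sum(batch_sizes)

# Weighted average, as in the training loop: sum(loss * batch_size) / dataset size
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / total_samples

# Naive average of batch means, which over-weights the small last batch
naive = sum(batch_losses) / len(batch_losses)

print(f"weighted={weighted:.4f} naive={naive:.4f}")
```

Here the small final batch happens to have a low loss, so the naive average understates the true per-sample loss.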
8. Save the trained model:
```python
torch.save(model.state_dict(), 'path/to/save/model.pth')
```
9. Load the saved model and evaluate it on the test set:
```python
model.load_state_dict(torch.load('path/to/save/model.pth'))
model.eval()  # switch to evaluation mode
test_loss = 0.0
correct = 0
with torch.no_grad():  # no gradients needed for evaluation
    for images, labels in test_loader:
        outputs = model(images.float())
        loss = criterion(outputs, labels)
        test_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)  # index of the largest logit
        correct += (predicted == labels).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = correct / len(test_loader.dataset) * 100
print(f"Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%")
```
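10. Use the model to predict a single image:
```python
from PIL import Image
image_path = 'path/to/your/image.jpg'
# Convert to grayscale and resize to the 28x28 input the model expects
image = Image.open(image_path).convert('L').resize((28, 28))
image = np.array(image)
# If the training images are light digits on a dark background (as in MNIST),
# a dark-on-light photo may need inverting first: image = 255 - image
image = np.expand_dims(image, axis=0)         # channel dimension
image = torch.from_numpy(image).unsqueeze(0)  # batch dimension -> (1, 1, 28, 28)
model.eval()
with torch.no_grad():
    output = model(image.float())
_, predicted = torch.max(output, 1)
print(f"Predicted Label: {predicted.item()}")
```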
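The model returns raw logits, so the prediction above only tells you which class scored highest, not how confident the model is. If you also want a probability, you can apply softmax to the logits. The sketch below does this in plain Python with hypothetical logit values; in PyTorch, `torch.softmax(output, dim=1)` computes the same thing.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the ten digit classes (illustrative values only)
logits = [-1.2, 0.3, 0.1, 4.0, -0.5, 0.0, -2.0, 1.5, 0.2, -0.7]
probs = softmax(logits)

# The predicted class is the index with the highest probability
predicted = max(range(len(probs)), key=probs.__getitem__)
print(f"Predicted Label: {predicted}, confidence: {probs[predicted]:.3f}")
```

Softmax does not change which class has the largest value, so the predicted label matches the one obtained from `torch.max` on the logits.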
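Replace `path/to/your/csv/file.csv`, `path/to/save/model.pth`, and `path/to/your/image.jpg` with your actual file paths. Also make sure the CSV file (label in the first column followed by 784 pixel values, no header row) and the image you want to test match what the code expects.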