CNN手写数字识别csv读取保存加载pytorch模型
时间: 2023-08-04 11:03:38 浏览: 113
基于pytorch实现的MNIST+CNN模型实现对手写数字的识别代码+报告.zip
5星 · 资源好评率100%
你可以按照以下步骤在PyTorch中实现CNN手写数字识别,包括使用CSV文件进行数据读取、保存和加载模型:
1. 导入所需的库和模块:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader, Dataset
```
2. 创建一个自定义的数据集类,用于读取CSV文件中的数据:
```python
class DigitDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
image = self.data.iloc[idx, 1:].values.reshape(28, 28).astype('float32') / 255.0
label = self.data.iloc[idx, 0]
return image, label
```
3. 定义CNN模型:
```python
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.fc = nn.Linear(7*7*32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
```
4. 定义训练函数和测试函数:
```python
def train(model, train_loader, criterion, optimizer):
model.train()
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images.unsqueeze(1))
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
def test(model, test_loader):
model.eval()
correct = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images.unsqueeze(1))
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(test_loader.dataset)
return accuracy
```
5. 加载数据集并创建数据加载器:
```python
train_dataset = DigitDataset('train.csv')
test_dataset = DigitDataset('test.csv')
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
```
6. 创建CNN模型实例、损失函数和优化器:
```python
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
7. 进行训练和测试:
```python
num_epochs = 10
for epoch in range(num_epochs):
train(model, train_loader, criterion, optimizer)
accuracy = test(model, test_loader)
print(f'Epoch {epoch+1}, Test Accuracy: {accuracy}')
torch.save(model.state_dict(), 'digit_model.pt')
```
8. 加载保存的模型并进行预测:
```python
model = CNN()
model.load_state_dict(torch.load('digit_model.pt'))
# 假设有一个名为image的张量用于预测
output = model(image.unsqueeze(0).unsqueeze(0))
_, predicted = torch.max(output.data, 1)
print(f'Predicted digit: {predicted.item()}')
```
这就是使用CSV文件进行手写数字识别的基本步骤。你可以根据自己的需求进行修改和优化。
阅读全文