PyTorch code for recognizing your own handwritten digits, including model training and image processing
Date: 2024-10-01 18:10:57
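Recognizing your own handwritten digits with PyTorch usually involves four steps: data preprocessing, building the neural network model, training the model, and testing it. Here is a simplified example.
First, install the required libraries (`torch` and `torchvision`, e.g. with `pip install torch torchvision`) and import them: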
```python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
```
1. **Data preprocessing**:
```python
transform = transforms.Compose([
    transforms.ToTensor(),                 # convert to a tensor and scale pixels to [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # normalize to [-1, 1]
])

# Load the MNIST dataset (downloaded to ./data on the first run)
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
```
2. **Building the model**:
```python
class HandwrittenDigitRecognizer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 1x28x28 -> 32x28x28
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)  # 32x28x28 -> 32x14x14
        self.fc1 = torch.nn.Linear(32 * 14 * 14, 128)
        self.fc2 = torch.nn.Linear(128, 10)  # output layer: 10 classes for the digits 0-9

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(-1, 32 * 14 * 14)  # flatten for the fully connected layer
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x  # raw logits; CrossEntropyLoss applies log-softmax internally


model = HandwrittenDigitRecognizer()
criterion = torch.nn.CrossEntropyLoss()  # loss function
optimizer = torch.optim.Adam(model.parameters())  # Adam optimizer (default learning rate 1e-3)
```
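The input size of `fc1` (32 * 14 * 14) follows from the layer settings: a 3x3 convolution with `stride=1` and `padding=1` keeps the 28x28 spatial size, and the 2x2 max-pool with stride 2 halves it to 14x14, leaving 32 × 14 × 14 = 6272 features per image. A quick shape check with a dummy batch (the batch size of 4 is arbitrary):

```python
import torch

# Re-create the convolution and pooling layers used in HandwrittenDigitRecognizer
conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

# A dummy batch of 4 single-channel 28x28 images (the MNIST image size)
x = torch.zeros(4, 1, 28, 28)

after_conv = conv1(x)                          # padding=1 keeps 28x28 -> (4, 32, 28, 28)
after_pool = maxpool(after_conv)               # 2x2 pooling halves it -> (4, 32, 14, 14)
flat = torch.flatten(after_pool, start_dim=1)  # (4, 32 * 14 * 14) = (4, 6272)

print(after_conv.shape, after_pool.shape, flat.shape)
```

`torch.flatten(x, start_dim=1)` is an equivalent way to write the `x.view(-1, 32 * 14 * 14)` step that does not hard-code the feature count, which helps if the layer sizes are changed later.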
3. **Training the model**:
```python
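num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # put the model in training mode
    for images, labels in train_loader:
        optimizer.zero_grad()  # reset gradients from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()  # backpropagation
        optimizer.step()  # update the weights
    # End of the epoch: print progress (this is the loss of the last batch only)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")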
```
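In the loop above, the loss printed after each epoch is only the loss of the last mini-batch, which can be noisy. The sketch below instead reports the average loss over the whole epoch and then saves the trained weights with `torch.save(model.state_dict(), ...)`, so they can be reloaded later to classify your own images. The helper name `train_one_epoch`, the file name `mnist_demo.pth`, the tiny placeholder model, and the random stand-in data are all assumptions for illustration; in practice you would pass `HandwrittenDigitRecognizer()` and `train_loader`.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer):
    """Run one training epoch and return the mean loss over all samples (hypothetical helper)."""
    model.train()
    total_loss = 0.0
    total_samples = 0
    for images, labels in loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss is the mean over the batch, so weight it by the batch size
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / total_samples


# --- Demo on random data (a stand-in for MNIST, NOT real digits) ---
torch.manual_seed(0)
fake_images = torch.randn(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=64, shuffle=True)

# Tiny placeholder model; in the article this would be HandwrittenDigitRecognizer()
demo_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(demo_model.parameters())

epoch_losses = [train_one_epoch(demo_model, fake_loader, criterion, optimizer) for _ in range(3)]
print(epoch_losses)

# Save only the learned weights (the state_dict); "mnist_demo.pth" is a hypothetical file name
ckpt_path = os.path.join(tempfile.gettempdir(), "mnist_demo.pth")
torch.save(demo_model.state_dict(), ckpt_path)

# Reload: build a model with the same architecture, then load the weights into it
restored = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
restored.load_state_dict(torch.load(ckpt_path))
```

Saving the `state_dict` rather than the whole model object keeps the checkpoint independent of how the class was defined; the trade-off is that the same architecture must be constructed again before loading.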
4. **Testing the model**:
```python
model.eval()  # switch to evaluation mode
correct_count = 0
total_count = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted digit
        total_count += labels.size(0)
        correct_count += (predicted == labels).sum().item()
accuracy = 100 * correct_count / total_count
print(f"Test Accuracy: {accuracy:.2f}%")
```