写一个基于pytorch的手写英文字母识别系统,要求输出训练结果,可以通过函数调用训练好的模型进行新图片的预测
时间: 2024-02-12 22:03:59 浏览: 109
好的,下面是一个简单的基于PyTorch的手写英文字母识别系统,包括训练和预测两个部分。具体实现步骤如下:
1. 准备数据集
我们使用EMNIST数据集,它包含手写字母和数字,总共有814,255个样本。我们只选择其中的英文字母部分,共有26个类别,每个类别有2400个样本。
```python
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
# 数据集预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor
transforms.Normalize((0.1307,), (0.3081,)) # 标准化
])
# 加载EMNIST数据集
emnist_dataset = torchvision.datasets.EMNIST(root='./data', split='letters', train=True,
download=True, transform=transform)
# 划分训练集、测试集
train_dataset, test_dataset = random_split(emnist_dataset, [60000, 6974])
```
2. 定义模型
我们使用一个简单的卷积神经网络模型,包括两个卷积层和两个全连接层。
```python
import torch.nn as nn
import torch.nn.functional as F
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
self.fc1 = nn.Linear(4*4*64, 256)
self.fc2 = nn.Linear(256, 26)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 4*4*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net()
```
3. 训练模型
我们使用交叉熵损失函数和Adam优化器来训练模型。
```python
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
def train(model, train_loader, optimizer, criterion, epochs):
for epoch in range(1, epochs+1):
train_loss = 0.0
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()*data.size(0)
train_loss /= len(train_loader.dataset)
print('Epoch: {}, Training Loss: {:.6f}'.format(epoch, train_loss))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
train(model, train_loader, optimizer, criterion, epochs=10)
```
4. 测试模型
我们使用测试集来评估模型的性能。
```python
# 测试模型
def test(model, test_loader):
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
print('Accuracy on test set: {:.2f}%'.format(100.0*correct/total))
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
test(model, test_loader)
```
5. 保存模型
我们将训练好的模型保存到本地文件中。
```python
# 保存模型
torch.save(model.state_dict(), 'handwriting.pth')
```
6. 加载模型
我们可以通过读取本地文件来加载训练好的模型。
```python
# 加载模型
model.load_state_dict(torch.load('handwriting.pth'))
```
7. 预测新图片
我们可以通过调用训练好的模型来预测新图片的类别。
```python
# 预测新图片
def predict(model, image):
image = transform(image).unsqueeze(0)
output = model(image)
_, predicted = torch.max(output.data, 1)
return predicted.item()
image = Image.open('test.png').convert('L')
predicted = predict(model, image)
print('Predicted class: {}'.format(predicted))
```
完整代码如下:
阅读全文