写一个基于pytorch的手写英文字母识别系统,要求输出训练结果,可以通过函数调用训练好的模型进行新图片的预测,有目录结构
时间: 2024-02-12 12:04:02 浏览: 87
好的,这里是一个基于PyTorch的手写英文字母识别系统,包括训练和预测两个部分,同时还有目录结构和详细的注释。
目录结构:
```
handwriting_recognition/
|---- data/ # 存放EMNIST数据集
|---- model/ # 存放训练好的模型
|---- src/ # 存放源代码
| |---- train.py # 训练模型
| |---- predict.py # 预测新图片
|---- test.png # 测试图片
|---- README.md # 说明文档
```
train.py:
```python
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
self.fc1 = nn.Linear(4*4*64, 256)
self.fc2 = nn.Linear(256, 26)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 4*4*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train():
# 数据集预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor
transforms.Normalize((0.1307,), (0.3081,)) # 标准化
])
# 加载EMNIST数据集
emnist_dataset = torchvision.datasets.EMNIST(root='./data', split='letters', train=True,
download=True, transform=transform)
# 划分训练集、测试集
train_dataset, test_dataset = random_split(emnist_dataset, [60000, 6974])
# 定义模型
model = Net()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
for epoch in range(1, 11):
train_loss = 0.0
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.item()*data.size(0)
train_loss /= len(train_loader.dataset)
print('Epoch: {}, Training Loss: {:.6f}'.format(epoch, train_loss))
# 测试模型
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
print('Accuracy on test set: {:.2f}%'.format(100.0*correct/total))
# 保存模型
torch.save(model.state_dict(), './model/handwriting.pth')
if __name__ == '__main__':
train()
```
predict.py:
```python
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
self.fc1 = nn.Linear(4*4*64, 256)
self.fc2 = nn.Linear(256, 26)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 4*4*64)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def predict(image_path):
# 加载模型
model = Net()
model.load_state_dict(torch.load('./model/handwriting.pth'))
# 图片预处理
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
image = Image.open(image_path)
image = transform(image).unsqueeze(0)
# 预测图片类别
output = model(image)
_, predicted = torch.max(output.data, 1)
return predicted.item()
if __name__ == '__main__':
image_path = './test.png'
predicted = predict(image_path)
print('Predicted class: {}'.format(predicted))
```
注意:在运行predict.py之前,需要先运行train.py来训练模型并保存到本地文件中。同时,要确保test.png文件存在于当前目录中。
希望对你有帮助!
阅读全文