Help me write a PyTorch program that trains a handwritten digit recognizer on the MNIST dataset
Posted: 2023-10-04 15:03:20
Implementing MNIST handwritten digit recognition with PyTorch
Sure. Here is a simple handwritten digit recognition program based on PyTorch and the MNIST dataset:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 64
learning_rate = 0.01
num_epochs = 10

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = nn.MaxPool2d(2)
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        x = self.pooling(F.relu(self.conv1(x)))
        x = self.pooling(F.relu(self.conv2(x)))
        x = x.view(-1, 320)
        # Return raw logits: nn.CrossEntropyLoss applies log_softmax
        # internally, so applying log_softmax here as well would
        # double-apply it.
        return self.fc(x)

# Instantiate the model and define the optimizer and loss function
model = Net()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.5)
criterion = nn.CrossEntropyLoss()

# Train the model
model.train()
for epoch in range(num_epochs):
    for i, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(data), len(train_loader.dataset),
                100. * i / len(train_loader), loss.item()))

# Evaluate the model
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        # Sum the per-sample loss so dividing by the dataset size below
        # gives a correct average
        test_loss += F.cross_entropy(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset), accuracy))
```
In this program we use a simple convolutional neural network consisting of two convolutional layers and one fully connected layer. During training we use a stochastic gradient descent optimizer (with momentum) and the cross-entropy loss; during testing we report the model's accuracy and average test loss.
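The fully connected layer's input size of 320 comes directly from the tensor shapes: each 28×28 image becomes 24×24 after the first 5×5 convolution, 12×12 after pooling, then 8×8 after the second convolution and 4×4 after the second pooling; with 20 output channels that is 20 × 4 × 4 = 320 features per image. A minimal, self-contained shape check illustrating this (the final activation is omitted here since it does not affect shapes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer configuration as the network above, redefined so this
# snippet runs on its own without the MNIST download
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = nn.MaxPool2d(2)
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        # 28x28 -> conv(5x5) -> 24x24 -> pool(2) -> 12x12
        x = self.pooling(F.relu(self.conv1(x)))
        # 12x12 -> conv(5x5) -> 8x8 -> pool(2) -> 4x4, with 20 channels
        x = self.pooling(F.relu(self.conv2(x)))
        # Flatten: 20 channels * 4 * 4 = 320 features per image
        x = x.view(-1, 320)
        return self.fc(x)

model = Net()
dummy = torch.randn(2, 1, 28, 28)  # a fake batch of two MNIST-sized images
out = model(dummy)
print(out.shape)  # one logit per digit class: [2, 10]
```

Running a dummy batch through a model like this is a quick way to catch shape mismatches before starting a full training run.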