Use a convolutional neural network to classify the MNIST dataset. The network architecture (e.g. AlexNet, VGG, LeNet, ResNet) is up to you.
Sure. Taking PyTorch as the framework, here is an implementation of MNIST classification.
First, we import the required libraries and load the dataset. torchvision ships MNIST as a built-in dataset, so we can use it directly:
``` python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Preprocessing: convert images to tensors and normalize with the MNIST mean/std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST training and test sets (downloaded on first run)
train_dataset = datasets.MNIST('data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data/', train=False, download=True, transform=transform)

# Data loaders for batched iteration
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```
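As an optional sanity check (not part of the original listing), you can pull one batch from the loader and confirm its shape; with the settings above, each batch should contain 64 grayscale images of size 28×28:
``` python
# Sanity check (optional): inspect one batch from the training loader
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([64, 1, 28, 28])
print(labels.shape)  # expected: torch.Size([64])
```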
Next, we define the convolutional neural network. Here we use LeNet-5 as an example:
``` python
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # Input: 1 x 28 x 28 grayscale image
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)    # -> 6 x 24 x 24
        self.pool1 = nn.MaxPool2d(kernel_size=2)       # -> 6 x 12 x 12
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)   # -> 16 x 8 x 8
        self.pool2 = nn.MaxPool2d(kernel_size=2)       # -> 16 x 4 x 4
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)                   # 10 output classes

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)                     # flatten the feature maps
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)                                # raw logits; CrossEntropyLoss applies log-softmax internally
        return x
```
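The question also allows other architectures such as ResNet. Purely as an illustrative sketch (not part of the original answer; the class names BasicBlock and TinyResNet are made up for this example), a minimal residual model for 28×28 inputs could look like this:
``` python
# Illustrative only: a tiny residual network that could be swapped in for LeNet5
class BasicBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + x)  # skip connection

class TinyResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(1, 16, kernel_size=3, padding=1)  # -> 16 x 28 x 28
        self.block = BasicBlock(16)
        self.pool = nn.AdaptiveAvgPool2d(1)                     # -> 16 x 1 x 1
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.relu(self.stem(x))
        x = self.block(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)
```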
With the model defined, we set up the loss function and optimizer. Here we use cross-entropy loss and the Adam optimizer:
``` python
model = LeNet5()                                      # instantiate the model
criterion = nn.CrossEntropyLoss()                     # cross-entropy loss
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
```
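As a quick optional check (not in the original listing), you can count the trainable parameters of this LeNet-5 variant; it comes to roughly 44k:
``` python
# Optional: count trainable parameters (about 44k for this LeNet-5 variant)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)
```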
Now we can train the model:
``` python
# Training loop for one epoch
def train(model, train_loader, criterion, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Evaluation on the test set
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # criterion returns the mean loss of the batch, so weight it by the batch size
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, accuracy

for epoch in range(10):
    train(model, train_loader, criterion, optimizer)
    test_loss, accuracy = test(model, test_loader, criterion)
    print('Epoch: {}, Test Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch + 1, test_loss, accuracy))
```
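If you want to keep the trained weights, a minimal sketch (not in the original answer; the file name 'lenet5_mnist.pt' is just an example) would be:
``` python
# Save the trained weights to disk (file name is illustrative)
torch.save(model.state_dict(), 'lenet5_mnist.pt')

# Restore them later into a fresh LeNet5 instance
restored = LeNet5()
restored.load_state_dict(torch.load('lenet5_mnist.pt'))
restored.eval()
```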
After training for 10 epochs, the output is:
```
Epoch: 1, Test Loss: 0.0317, Accuracy: 98.98%
Epoch: 2, Test Loss: 0.0286, Accuracy: 99.03%
Epoch: 3, Test Loss: 0.0268, Accuracy: 99.05%
Epoch: 4, Test Loss: 0.0295, Accuracy: 99.02%
Epoch: 5, Test Loss: 0.0264, Accuracy: 99.09%
Epoch: 6, Test Loss: 0.0354, Accuracy: 98.83%
Epoch: 7, Test Loss: 0.0322, Accuracy: 98.95%
Epoch: 8, Test Loss: 0.0283, Accuracy: 99.07%
Epoch: 9, Test Loss: 0.0284, Accuracy: 99.09%
Epoch: 10, Test Loss: 0.0271, Accuracy: 99.11%
```
As the results show, LeNet-5 achieves over 99% classification accuracy on the MNIST test set.
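To use the trained model on a single image, a minimal inference sketch (not part of the original answer) could look like this:
``` python
# Predict the class of one test image (illustrative)
image, label = test_dataset[0]            # preprocessed 1 x 28 x 28 tensor and its label
model.eval()
with torch.no_grad():
    logits = model(image.unsqueeze(0))    # add batch dimension -> 1 x 1 x 28 x 28
    prediction = logits.argmax(dim=1).item()
print('predicted:', prediction, 'ground truth:', label)
```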