请帮我写一段用pytorch搭建卷积神经网络实现MNIST分类实验的结果分析
时间: 2024-11-30 12:18:20 浏览: 13
基于Pytorch框架手动构建LeNet网络实现MNIST数据集图像识别分类源码+数据集+模型.zip
使用PyTorch构建卷积神经网络(Convolutional Neural Network, CNN)对MNIST手写数字数据集进行分类是一个常见的深度学习入门案例。首先,让我们简述一下实验的基本步骤和结果分析:
1. **数据加载**:从PyTorch的`torchvision`库导入`MNIST`数据集,进行预处理,包括归一化像素值到0-1范围。
```python
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
```
2. **模型定义**:创建一个包含几个卷积层、池化层以及全连接层的CNN模型。
```python
import torch.nn as nn
from torchsummary import summary
model = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, 10) # 输出层,对应MNIST的10个类别
)
summary(model, input_size=(1, 28, 28))
```
3. **训练过程**:定义损失函数(如交叉熵),优化器(如Adam)并训练模型。
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
```
4. **测试与评估**:在验证集上运行模型并计算准确率。
```python
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Accuracy on the test set: {accuracy}%")
```
结果分析:
- **高精度**:如果模型达到90%以上的准确率,说明它已经成功地学习了MNIST数据集的特征并能有效地分类新的手写数字图片。
- **过拟合与欠拟合**:检查训练集和验证集的表现差异,如果训练集准确度很高而验证集低,则可能存在过拟合;反之则可能是欠拟合。
- **调整超参数**:可以通过改变学习率、网络结构、批量大小等来优化模型性能。
- **时间与资源消耗**:记录训练时间和GPU利用率,有助于评估模型的效率和是否适合实际部署。
阅读全文