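Load the MNIST dataset, build and train a model, report the model's accuracy on the test set, save the model, and use it to classify images of your own handwriting. Requirement 1: use PyTorch's nn module for training and testing. Requirement 2: the model must be able to classify a single image.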
Time: 2024-10-19 11:12:10 Views: 40
To load the MNIST dataset, you can use torchvision, the companion library to PyTorch that ships common image datasets. The basic steps for this task are:
1. **Import the required libraries**:
```python
import torch
from torch import nn, optim
from torchvision import datasets, transforms
```
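2. **Load the MNIST dataset**:
```python
transform = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a float tensor in the range [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # normalize with the MNIST mean and std
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Split the 60,000 training images into training and validation subsets
# (use the validation set, not the test set, for early stopping and other tuning)
train_subset, val_subset = torch.utils.data.random_split(train_dataset, [50000, 10000])
# Wrap each dataset in a DataLoader so the loops below receive batches of (images, labels)
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_subset, batch_size=1000)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)
```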
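To make the normalization step concrete, here is a minimal sketch of the arithmetic that `transforms.Normalize((0.1307,), (0.3081,))` applies to each pixel once `ToTensor()` has scaled it to [0, 1]. The three sample pixel values are made up for illustration.

```python
import torch

mean, std = 0.1307, 0.3081  # the MNIST mean and std passed to Normalize

# Illustrative pixel values after ToTensor(): black, mean-gray, white
pixels = torch.tensor([0.0, 0.1307, 1.0])

# Normalize computes (x - mean) / std for every pixel
normalized = (pixels - mean) / std
print(normalized)
```

A pixel equal to the dataset mean maps to zero, which is why the background of a normalized MNIST image is slightly negative rather than zero.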
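3. **Build the neural network model**:
```python
model = nn.Sequential(
    nn.Flatten(),  # flatten each 1x28x28 image into a 784-element vector
    nn.Linear(28*28, 128),  # fully connected layer
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)  # output layer: one logit per class (digits 0-9)
)
```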
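Before training, it can help to confirm that the network accepts MNIST-shaped input and produces ten logits per image. The sketch below rebuilds the same architecture and feeds it a random batch; the batch size of 4 is an arbitrary choice for the check.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

dummy_batch = torch.randn(4, 1, 28, 28)  # 4 fake grayscale 28x28 images
logits = model(dummy_batch)              # expected shape: (4, 10)

# Count trainable parameters (weights and biases of the three Linear layers)
num_params = sum(p.numel() for p in model.parameters())
print(logits.shape, num_params)
```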
4. **Define the loss function and optimizer**:
```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
```
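The model above ends in a plain `nn.Linear` layer with no softmax. That works because `nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally. The following sketch checks this on random, made-up logits and labels.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 10)       # fake model outputs for 3 samples
labels = torch.tensor([2, 7, 0])  # fake ground-truth digits

# CrossEntropyLoss on raw logits ...
loss_ce = nn.CrossEntropyLoss()(logits, labels)
# ... matches log-softmax followed by negative log-likelihood
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(loss_ce.item(), loss_manual.item())
```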
5. **Train the model**:
```python
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # training mode
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # After each epoch, evaluate the model on the validation set
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            predictions = torch.argmax(model(images), dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Accuracy: {accuracy:.4f}")
```
6. **Evaluate the model and report its accuracy on the test set**:
```python
model.eval()  # evaluation mode
with torch.no_grad():
    correct_test = 0
    total_test = 0
    for images, labels in test_loader:
        predictions = torch.argmax(model(images), dim=1)
        correct_test += (predictions == labels).sum().item()
        total_test += labels.size(0)
test_accuracy = correct_test / total_test
print(f"Test Accuracy: {test_accuracy:.4f}")
```
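The validation and test loops share the same counting logic, so it could be factored into one function. The sketch below uses a hypothetical helper named `evaluate` and checks it on a toy model and a hand-made one-batch "loader" instead of real MNIST data.

```python
import torch
from torch import nn

def evaluate(model, loader):
    """Return the fraction of correctly classified samples in loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            predictions = torch.argmax(model(images), dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Toy check: nn.Flatten leaves (N, 3) scores unchanged, so argmax picks the largest column
toy_model = nn.Flatten()
toy_loader = [(torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), torch.tensor([0, 2]))]
toy_accuracy = evaluate(toy_model, toy_loader)  # first sample right, second wrong
print(toy_accuracy)
```

With the real objects, the call would look like `evaluate(model, test_loader)`.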
7. **Save the model**:
```python
torch.save(model.state_dict(), 'mnist_model.pth')
```
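Because only the `state_dict` is saved, loading the model later means rebuilding the same architecture first and then restoring the weights into it. The following is a minimal sketch of that round trip; `build_model` is a hypothetical helper, and a temporary directory stands in for the real save location.

```python
import os
import tempfile

import torch
from torch import nn

def build_model():
    # Must match the architecture that produced the saved weights
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
    )

original = build_model()
path = os.path.join(tempfile.mkdtemp(), "mnist_model.pth")
torch.save(original.state_dict(), path)

# Later (for example in a separate script): rebuild, then load the weights
restored = build_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()

# Both models should give identical outputs for the same input
sample = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    same_output = torch.allclose(original(sample), restored(sample))
print(same_output)
```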
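8. **Classify a handwritten image**:
```python
from PIL import Image  # needed to open the image file

def predict_image(image_path):
    # MNIST images are 28x28 grayscale with a light digit on a dark background;
    # invert your own image first if it is dark ink on light paper
    image = Image.open(image_path).convert('L').resize((28, 28))
    tensor_image = transform(image).unsqueeze(0)  # add a batch dimension
    model.eval()  # switch to inference mode
    with torch.no_grad():
        prediction = torch.argmax(model(tensor_image), dim=1)
    return f"Your image is classified as: {prediction.item()}"

# Use predict_image to classify your own handwritten image
```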