pytorch训练mnist
时间: 2024-10-05 17:03:10 浏览: 11
PyTorch是一种流行的深度学习框架,非常适合用于训练MNIST手写数字识别任务。MNIST是一个基础的数据集,包含60,000张28x28像素的手写数字训练图片和10,000张测试图片。
以下是使用PyTorch训练MNIST的基本步骤:
1. **数据预处理**:
- 导入所需的库(如`torch`, `torchvision`),并下载MNIST数据集。
```python
import torch
from torchvision.datasets import MNIST
train_dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
```
2. **数据加载**:
- 使用`DataLoader`将数据划分为小批次。
```python
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
```
3. **构建模型**:
- 通常选择卷积神经网络(CNN),例如LeNet或简单的全连接网络。
```python
model = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
nn.Linear(7 * 7 * 32, 128),
nn.ReLU(),
nn.Linear(128, 10) # 输出层,10个节点对应0-9的分类
)
```
4. **定义损失函数和优化器**:
- Softmax交叉熵作为损失函数,SGD或Adam等优化器。
```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
5. **训练循环**:
- 循环遍历每个训练批次,前向传播、计算损失、反向传播和更新权重。
```python
for epoch in range(num_epochs):
for images, labels in train_loader:
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
6. **评估模型**:
- 测试集上运行模型,查看准确率。
```python
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Test Accuracy: {correct / total * 100:.2f}%")
```