pycham实现mnist手写数字
时间: 2024-11-10 15:26:37 浏览: 18
SVM 实现MNIST手写数字图像识别
PyTorch是一个开源的人工智能库,特别适合于动态计算图模型如深度学习。要使用PyTorch实现MNIST(Mixed National Institute of Standards and Technology)手写数字识别任务,可以按照以下步骤操作:
1. 导入所需库:首先需要导入`torch`, `torchvision`以及数据处理和可视化工具。
```python
import torch
import torchvision
from torchvision import datasets, transforms
```
2. 数据预处理:对MNIST数据集进行归一化,并将其转换为张量。
```python
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
```
3. 创建数据加载器:使用`DataLoader`将数据分批加载到内存中。
```python
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```
4. 定义网络模型:使用PyTorch构建一个简单的卷积神经网络(CNN),例如LeNet、ResNet等。
```python
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 16, kernel_size=5), # 卷积层
torch.nn.ReLU(),
torch.nn.MaxPool2d(2),
torch.nn.Conv2d(16, 32, kernel_size=5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2),
torch.nn.Flatten(), # 将二维数组展平成一维
torch.nn.Linear(32 * 7 * 7, 10) # 全连接层
)
```
5. 损失函数和优化器选择:交叉熵损失函数适用于分类任务,Adam是一种常用的优化算法。
```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
```
6. 训练过程:通过训练循环迭代数据,更新网络参数。
```python
num_epochs = 10
for epoch in range(num_epochs):
for images, labels in train_loader:
optimizer.zero_grad() # 清零梯度
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
```
7. 测试模型:在测试数据上评估模型性能。
```python
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Test Accuracy: {100 * correct / total}%")
```
阅读全文