MNIST Handwritten Digit Recognition in Python
Date: 2024-12-30 17:35:01
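### Implementing MNIST Handwritten Digit Recognition in Python
#### Loading the required libraries and modules
To complete this task, import the libraries needed to handle the data, build the neural network, and train it.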
```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
```
#### Configuring the environment
Select the available hardware (CPU or GPU) and set the random seed so that runs are reproducible[^1].
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")
torch.manual_seed(0)
if device == "cuda":
    # Trade some speed for deterministic cuDNN kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```
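#### Downloading and preprocessing the dataset
Download the MNIST dataset and normalize it. The constants 0.1307 and 0.3081 passed to `Normalize` are the commonly used mean and standard deviation of the MNIST training pixels. Normalizing the inputs is an important step for model performance[^3].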
```python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, shuffle=False)
```
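To make the effect of the transform concrete, the sketch below reproduces its arithmetic in plain Python, without PyTorch. It assumes that `ToTensor` scales 8-bit pixel values into the range [0, 1] and that `Normalize` then computes `(x - mean) / std`. The helper names `normalize_pixel` and `num_batches` are hypothetical and exist only for illustration.

```python
import math

MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize


def normalize_pixel(raw: int) -> float:
    """Mimic ToTensor (divide by 255) followed by Normalize ((x - mean) / std)."""
    scaled = raw / 255.0
    return (scaled - MEAN) / STD


def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


black = normalize_pixel(0)      # background pixel
white = normalize_pixel(255)    # brightest stroke pixel
train_batches = num_batches(60000, 64)    # MNIST training set, batch_size = 64
test_batches = num_batches(10000, 128)    # MNIST test set, batch_size * 2
print(black, white, train_batches, test_batches)
```

Under these assumptions a background pixel maps to roughly -0.42 and a fully white pixel to roughly 2.82. One epoch covers 938 training batches, which is why the training loop can report the loss every 100 batches.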
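#### Building the convolutional neural network (CNN)
Define a simple CNN for image classification. The network has two convolutional layers, each followed by max pooling, and then two fully connected layers that produce the final prediction.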
```python
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 1x28x28 input -> 32x24x24 (5x5 kernel, no padding)
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(5, 5))
        # 2x2 max pooling halves each spatial dimension
        self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        # 32x12x12 -> 64x8x8, which pooling reduces to 64x4x4
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5))
        self.fc1 = torch.nn.Linear(in_features=64*4*4, out_features=512)
        self.fc2 = torch.nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64*4*4)  # Flatten to (batch, 1024)
        x = torch.relu(self.fc1(x))
        output = self.fc2(x)  # Raw logits, one per digit class
        return output


model = Net().to(device=device)
```
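The `64*4*4` input size of the first fully connected layer follows from the layer shapes. As a rough check, the sketch below applies the standard output-size formula for convolution and pooling, `(size + 2 * padding - kernel) // stride + 1`, to a 28x28 MNIST image. The helper name `conv_out` is hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                    # MNIST images are 28x28
size = conv_out(size, kernel=5)              # conv1: 28 -> 24
after_conv1 = size
size = conv_out(size, kernel=2, stride=2)    # pool:  24 -> 12
size = conv_out(size, kernel=5)              # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)    # pool:  8 -> 4
flat_features = 64 * size * size             # 64 channels of 4x4 feature maps
print(after_conv1, size, flat_features)
```

If the kernel size, padding, or input resolution changes, the same calculation gives the new value for `in_features` of `fc1` and for the `view` call in `forward`.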
#### Defining the loss function and optimizer
Use cross-entropy as the loss function and the Adam optimizer to update the weights.
```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
```
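`CrossEntropyLoss` takes raw logits and integer class labels, which is why the network's `forward` returns the output of `fc2` without a softmax. The plain-Python sketch below illustrates the per-sample quantity it is based on, the negative log-softmax of the target class. The function name `cross_entropy` is hypothetical; PyTorch's version additionally averages over the batch by default.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target class, using the log-sum-exp trick for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# An untrained model that scores all 10 digits equally has loss ln(10)
uniform_loss = cross_entropy([0.0] * 10, target=3)

# A confident, correct prediction drives the loss toward 0
confident_loss = cross_entropy([0.0, 0.0, 0.0, 8.0] + [0.0] * 6, target=3)
print(uniform_loss, confident_loss)
```

A loss near ln(10), about 2.30, in the first printed batches is therefore expected, and it should fall quickly as training proceeds.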
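#### Training the model
Set the number of epochs and other hyperparameters, then start the training loop. After each epoch, evaluate the current model on the test set.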
```python
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  # Report the mean loss over the last 100 batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}")
            running_loss = 0.0

    # Evaluate on the test set after every epoch
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # Index of the largest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Accuracy of the network on the {total} test images: {accuracy:.2f} %")
```
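The evaluation step reduces each row of logits to a predicted class by taking the index of the largest value, then counts matches against the labels. The sketch below shows the same computation on a tiny hand-written batch in plain Python. The function name `accuracy_percent` and the example numbers are made up for illustration.

```python
def accuracy_percent(batch_logits, batch_labels):
    """Percentage of samples whose largest logit is at the labelled class index."""
    correct = 0
    for logits, label in zip(batch_logits, batch_labels):
        predicted = max(range(len(logits)), key=logits.__getitem__)  # argmax
        correct += int(predicted == label)
    return 100 * correct / len(batch_labels)


# Four samples, three classes; the last sample is misclassified
logits = [
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 3.0],
    [0.9, 0.8, 0.7],
]
labels = [1, 0, 2, 1]
acc = accuracy_percent(logits, labels)
print(acc)
```

Because softmax preserves the ordering of the logits, taking the argmax of the raw outputs gives the same prediction as taking it over the probabilities.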