基于pytorch卷积神经网络手写数字识别 自适应调整学习率的实现 代码
时间: 2023-09-16 16:11:53 浏览: 96
以下是基于PyTorch实现的手写数字识别卷积神经网络,包含自适应调整学习率的代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义卷积神经网络模型
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.fc1 = nn.Linear(7*7*64, 1000)
self.fc2 = nn.Linear(1000, 10)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc1(out)
out = self.fc2(out)
return out
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./', train=False, transform=transforms.ToTensor())
# 定义训练参数
batch_size = 100
num_epochs = 5
# 定义损失函数和优化器
model = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# 训练模型
total_step = len(train_dataset) // batch_size
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_dataset):
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每100步打印一次训练状态
if (i+1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# 动态调整学习率
if (epoch+1) % 2 == 0:
for param_group in optimizer.param_groups:
param_group['lr'] /= 2
# 测试模型
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_dataset:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += 1
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
```
在训练过程中,每训练2个epoch,会将学习率除以2进行自适应调整。可以根据实际情况进行调整。
阅读全文