resnet18训练mnist数据集
时间: 2023-11-10 17:05:50 浏览: 196
要使用ResNet-18来训练MNIST数据集,首先需要导入相应的库,并获取MNIST数据集。然后,需要调整ResNet-18模型的全连接层的输出,以适应MNIST数据集的10个类别。接下来,可以开启训练过程,并在训练过程中显示损失值。训练完成后,可以保存训练好的模型权重文件。最后,在测试集上测试训练后模型的准确率。
下面是一种可能的实现:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
# 导入MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
# 创建ResNet-18模型
model = models.resnet18(pretrained=True)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10)
# 设置优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# 开始训练
num_epochs = 10
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print("Epoch {} - Training loss: {:.4f}".format(epoch+1, running_loss / len(train_loader)))
# 保存训练好的模型权重文件
torch.save(model.state_dict(), 'resnet18_mnist.pth')
# 在测试集上测试模型准确率
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print("Test accuracy: {:.2f}%".format(accuracy))
```
阅读全文