用pytorch写一个用预训练ResNet18模型训练MNIST的代码
时间: 2024-04-16 15:26:29 浏览: 156
当然可以!以下是使用预训练的ResNet18模型来训练MNIST数据集的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
# Fix the RNG seed so results are reproducible.
torch.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline.
# BUG FIX: MNIST images are single-channel grayscale, but ResNet18
# (pretrained on ImageNet) expects 3-channel input — the original
# pipeline would crash in model(inputs). Replicate the gray channel to
# 3 channels and normalize each channel with the MNIST statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),                 # ResNet18's native input size
    transforms.Grayscale(num_output_channels=3),   # 1 -> 3 channels
    transforms.ToTensor(),
    transforms.Normalize((0.1307, 0.1307, 0.1307),
                         (0.3081, 0.3081, 0.3081)),
])

# Download/load the MNIST train and test splits.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Mini-batch loaders; only the training data is shuffled.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Load ImageNet-pretrained ResNet18 and swap the classifier head for a
# 10-way linear layer (one output per MNIST digit).
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13;
# on modern versions prefer
#   models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Move the model to the selected device.
model = model.to(device)

# Multi-class cross-entropy loss; SGD with momentum as the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch.

    Iterates the dataloader once, performing a forward pass, backprop and
    an optimizer step per batch.

    Returns:
        (mean_batch_loss, accuracy) for the epoch, both as floats.
    """
    model.train()
    loss_sum = 0.0
    hits, seen = 0, 0
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Standard step: clear grads, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Accuracy bookkeeping: count correct argmax predictions.
        preds = logits.max(dim=1).indices
        seen += batch_y.size(0)
        hits += (preds == batch_y).sum().item()
        loss_sum += batch_loss.item()

    return loss_sum / len(dataloader), hits / seen
# 测试模型
def test(model, dataloader, criterion, device):
    """Evaluate the model on a dataset without updating weights.

    Runs in eval mode under torch.no_grad().

    Returns:
        (mean_batch_loss, accuracy) over the dataloader, both as floats.
    """
    model.eval()
    loss_sum = 0.0
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            # Count correct argmax predictions for the accuracy figure.
            preds = logits.max(dim=1).indices
            seen += batch_y.size(0)
            hits += (preds == batch_y).sum().item()
            loss_sum += batch_loss.item()

    return loss_sum / len(dataloader), hits / seen
# Main schedule: train for a fixed number of epochs, evaluating on the
# held-out test split after every epoch and reporting both metrics.
num_epochs = 10
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(model, train_loader, criterion, optimizer, device)
    te_loss, te_acc = test(model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"Train Loss: {tr_loss:.4f}, Train Acc: {tr_acc:.4f}")
    print(f"Test Loss: {te_loss:.4f}, Test Acc: {te_acc:.4f}")
    print()
```
这段代码首先加载MNIST数据集,并进行预处理。然后,使用预训练的ResNet18模型,将最后一层的输出更改为10个类别。接下来,定义损失函数和优化器。训练过程中,使用训练数据集对模型进行训练,并使用测试数据集对模型进行评估。最后,打印每个epoch的训练损失、训练准确率、测试损失和测试准确率。
请注意,这里的ResNet18模型是在ImageNet数据集上进行预训练的,因此可能需要更多的训练迭代来适应MNIST数据集。你可以根据需要调整超参数和训练迭代次数来获得更好的结果。
阅读全文