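Use Python code to demonstrate image classification with ResNet34: 1. Use the CIFAR-10 dataset (color images in 10 classes). 2. On the test set, compare how training from scratch differs from directly using pretrained parameters.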
Posted: 2025-01-04 07:40:04
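To classify CIFAR-10 images with ResNet34 in Python, we usually use the PyTorch library, because it provides pretrained models and a convenient API. The following is an overview of the basic steps:
First, make sure the required libraries, torch and torchvision, are installed (torch.nn ships as part of torch and needs no separate installation):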
```bash
pip install torch torchvision
```
Then we can start writing the code:
```python
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
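# 1. Load the ResNet34 model pretrained on ImageNet and freeze its weights
# (the weights= argument needs torchvision >= 0.13; older releases use pretrained=True)
model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False
# 2. Define a new fully connected layer to replace the original 1000-class classification layer
num_classes = 10  # CIFAR-10 has 10 classes
classifier = nn.Linear(model.fc.in_features, num_classes)
# Attach the new classification layer to the model (its parameters stay trainable)
model.fc = classifier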
# Data preprocessing
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    # Map each channel from [0, 1] to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
# 3. Train the model
# Only the newly added classification layer is updated, so a relatively low
# learning rate such as 0.001 can be used
learning_rate = 0.001
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()  # create the loss function once, outside the loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Start training
epochs = 10
model.train()
for epoch in range(epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# 4. Test the model
model.eval()  # switch to evaluation mode (BatchNorm uses its running statistics)
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Accuracy on the test set: {accuracy:.2f}%")
```
Comparing training from scratch with directly using the pretrained parameters: