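Adapt an ImageNet-pretrained model (VGG, ResNet) to the CIFAR10 classification task, train it on the CIFAR10 dataset until the model converges or a stopping condition is reached, and write the code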
Time: 2024-10-16 17:16:58
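Transferring an ImageNet-pretrained model (such as VGG or ResNet) to a smaller dataset with fewer classes, such as CIFAR-10, takes a few adjustments: the 32x32 images must be resized and normalized the way the pretrained network expects, and the final classification layer must be replaced with one that has 10 outputs. Below is a simplified Python example using PyTorch. It assumes `torch` and `torchvision` are installed; the `weights=` argument used to load the pretrained model requires torchvision 0.13 or later: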
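```python
import torch
from torchvision import models, transforms, datasets
from torch import nn, optim

# Preprocessing for CIFAR10: upsample the 32x32 images to the 224x224 input size
# the ImageNet models were trained on, then normalize with the ImageNet statistics
transform = transforms.Compose([
    transforms.Resize(224),  # resize straight to 224 so the whole image is kept (Resize(256) + CenterCrop(224) would cut off the borders)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization values
])

# Load the CIFAR10 dataset
cifar10_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create the DataLoaders
train_loader = torch.utils.data.DataLoader(cifar10_train_dataset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar10_val_dataset, batch_size=128, shuffle=False)

# Load an ImageNet-pretrained model (VGG16 here); `weights=` replaces the deprecated `pretrained=True`
model_vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
# The last layer outputs 1000 ImageNet classes, so replace it for CIFAR10's 10 classes
num_features = model_vgg.classifier[6].in_features  # input dimension of the last fully connected layer (4096)
model_vgg.classifier[6] = nn.Linear(num_features, 10)  # new linear layer with 10 output nodes

# Cross-entropy loss and the Adam optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_vgg.parameters(), lr=1e-4)  # fine-tuning usually needs a smaller lr than training from scratch; tune as needed

# Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vgg.to(device)
epochs = 10  # maximum number of epochs; set as appropriate

# evaluate() is defined in the next snippet; run that definition before this loop
for epoch in range(epochs):
    model_vgg.train()  # evaluate() switches to eval mode, so switch back each epoch (re-enables dropout)
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # clear the gradients
        outputs = model_vgg(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # backpropagation
        optimizer.step()  # update the weights
    # At the end of each epoch, validate and report metrics such as accuracy
    val_loss, val_acc = evaluate(model_vgg, val_loader)
    print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

# Adjust the learning rate, add early stopping, or tune other hyperparameters until the model converges or a predefined stopping condition is met
```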
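The closing comment of the training loop names early stopping as one way to reach a stopping condition. Below is a minimal sketch of such a helper in plain Python, assuming the validation loss is the monitored metric. The class name `EarlyStopping` and its `patience`/`min_delta` parameters are illustrative choices, not part of PyTorch or torchvision:

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop once the validation loss has not
    improved by more than `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0      # consecutive epochs without improvement
        self.improved = False    # True if the most recent step set a new best

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
            self.improved = True
        else:
            self.bad_epochs += 1
            self.improved = False
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopping(patience=2)
    # Made-up validation losses, for illustration only
    for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71], start=1):
        if stopper.step(val_loss):
            print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 4"
            break
```

To use it, create `stopper = EarlyStopping(patience=3)` before the epoch loop and add `if stopper.step(val_loss): break` right after the `evaluate(...)` call. When `stopper.improved` is true you can also save a checkpoint with `torch.save(model_vgg.state_dict(), "best_model.pt")` (the file name is arbitrary), so the best weights survive even if later epochs get worse. For the learning-rate part, `torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2)` with `scheduler.step(val_loss)` after each validation lowers the learning rate when the validation loss stops improving.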
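In this example, `evaluate` is a helper function that computes the loss and accuracy on the validation set. Define it before running the training loop above: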
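```python
import torch

def evaluate(model, data_loader):
    # Uses the global `device` and `criterion` defined in the training script
    model.eval()  # eval mode: disables dropout, BatchNorm uses its running statistics
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)  # index of the largest logit = predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # criterion returns the batch mean, so multiply by the batch size to accumulate a sum
            running_loss += criterion(outputs, labels).item() * inputs.size(0)
    return running_loss / total, correct / total  # average loss, accuracy
```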