写一个迁移学习的图像分类代码
时间: 2023-06-08 12:07:29 浏览: 54
好的,以下是一个迁移学习图像分类的代码示例:
```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
# 定义数据预处理步骤,例如将图像转换为张量和归一化等
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载训练和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
# 加载预训练模型(ResNet18)。这里使用不包含分类器的模型。
net = torchvision.models.resnet18(pretrained=True)
# 冻结所有层。仅训练添加的分类器
for param in net.parameters():
param.requires_grad = False
# 替换分类器
net.fc = nn.Linear(512, 10)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.fc.parameters(), lr=0.001, momentum=0.9)
# 训练分类器
for epoch in range(50):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 在分类器上前向传递
outputs = net(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印统计信息
running_loss += loss.item()
if i % 200 == 199:
print('[%d, %5d] training loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('Finished Training')
# 测试网络性能
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
```
这个示例使用了迁移学习的思想,将一个在ImageNet上预训练的ResNet18模型用于CIFAR10图像分类任务。它首先加载预训练模型,并用新的分类器替换原来的最后一层。然后,仅训练新添加的分类器。其余的参数都被冻结,以便保留预训练模型的特征提取能力。最后,该模型在测试集上获得了超过90%的准确率。该示例仅仅是用来展示迁移学习的基本思想和用法,实际上,还需要进一步的方法优化,比如学习率调整或fine-tuning等。