利用detach()实现迁移学习的小案例
时间: 2023-07-17 17:04:03 浏览: 50
下面是一个使用`detach()`方法实现迁移学习的简单案例:
假设我们有一个已经预训练好的模型,用于图像分类任务。我们希望将这个模型迁移到一个新的任务上,但是只更新最后一层的参数。
```python
import torch
import torch.nn as nn
# 加载预训练模型
pretrained_model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
# 冻结所有参数
for param in pretrained_model.parameters():
param.requires_grad = False
# 替换最后一层全连接层
pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, num_classes)
# 解冻最后一层的参数
for param in pretrained_model.fc.parameters():
param.requires_grad = True
# 创建优化器,只优化最后一层的参数
optimizer = torch.optim.SGD(pretrained_model.fc.parameters(), lr=0.001, momentum=0.9)
# 训练过程
for epoch in range(num_epochs):
for images, labels in dataloader:
# 前向传播
outputs = pretrained_model(images)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播并更新梯度
optimizer.zero_grad()
loss.backward()
# 分离计算图,只更新最后一层的参数
with torch.no_grad():
pretrained_model.fc.weight.grad = pretrained_model.fc.weight.grad.detach()
pretrained_model.fc.bias.grad = pretrained_model.fc.bias.grad.detach()
# 更新参数
optimizer.step()
```
在上述案例中,我们首先加载了一个预训练的ResNet-18模型,并冻结了所有参数。然后,我们替换了模型的最后一层全连接层,并解冻了这一层的参数。接下来,我们创建了一个只优化最后一层参数的优化器。在训练过程中,我们首先进行前向传播和损失计算,然后进行反向传播和梯度更新。在更新梯度之前,我们使用`detach()`方法分离计算图,只更新最后一层的参数,而不会对其他层的参数进行更新。这样就实现了迁移学习,只更新模型的最后一层参数,而保持其他层的预训练参数不变。