李沐深度学习微调代码
时间: 2023-10-04 10:06:38 浏览: 122
你好!以下是一个示例代码,用于进行深度学习模型的微调(fine-tuning):
```python
import torch
import torchvision
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 冻结所有参数
for param in model.parameters():
param.requires_grad = False
# 替换全连接层
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, num_classes)
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
# 加载数据集
train_dataset = torchvision.datasets.ImageFolder('path/to/train/dataset', transform=...)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# 训练模型
model.train()
for epoch in range(num_epochs):
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 保存微调后的模型
torch.save(model.state_dict(), 'fine_tuned_model.pth')
```
这是一个简单的示例代码,你可以根据自己的需求进行微调。你需要替换代码中的`path/to/train/dataset`为你自己的训练数据集的路径,并根据你的数据集进行相应的数据预处理操作。
当然,这只是一个基本的示例,实际的微调过程可能还需要考虑其他因素,如学习率调整、数据增强等。具体的微调策略会根据你的任务和数据集的特点而有所不同。希望对你有所帮助!
阅读全文