训练好的ResNet-18模型的代码
时间: 2023-09-11 13:10:09 浏览: 89
pytorch resnet18 预训练模型
以下是使用PyTorch框架加载预训练的ResNet-18模型的代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
# 加载预训练好的ResNet-18模型
model = models.resnet18(pretrained=True)
# 输出网络结构
print(model)
# 冻结所有卷积层的参数
for param in model.parameters():
param.requires_grad = False
# 替换最后一层全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
# 输出网络结构
print(model)
```
这段代码中,我们使用了`models.resnet18(pretrained=True)`来加载预训练好的ResNet-18模型。接着,我们将所有卷积层的参数都冻结,只训练最后一层全连接层。最后,我们用`nn.Linear`替换了ResNet-18模型的最后一层全连接层,以适应我们的分类任务。您可以根据自己的需求进行修改和调整。
阅读全文