python中如何加载训练好的模型
时间: 2023-06-04 15:05:17 浏览: 494
python训练好的模型保存与加载
首先,需要使用Python中的相应库来加载训练好的模型文件。常用的是PyTorch和TensorFlow。具体的步骤如下:
1. 导入相应的库,例如PyTorch或TensorFlow;
2. 定义模型的架构或载入现有的模型的架构(如果有的话);
3. 载入模型权重或参数;
4. 对模型进行测试或使用。
以下是一个示例代码,假设模型文件为'model.pth'。
```python
import torch
import torchvision.models as models
# 定义模型架构
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 10)
# 载入模型权重
model.load_state_dict(torch.load('model.pth'))
# 测试模型
input_data = torch.randn(1, 3, 224, 224)
output = model(input_data)
```
在这个示例中,首先我们导入PyTorch库和resnet18模型,然后重新定义了网络的输出层(原来是1000个类别,我们改成了10个类别),最后载入了我们训练好的权重文件'model.pth'。我们还定义了一个输入数据input_data,然后把它输入到我们的模型中得到了输出结果output。
阅读全文