pytorch训练好的模型保存和使用
时间: 2023-09-12 15:13:25 浏览: 66
在 PyTorch 中,可以使用 `torch.save()` 函数将训练好的模型保存到磁盘上。该函数需要传入两个参数:模型的状态字典和文件名。
以下是保存模型的示例代码:
```python
import torch
# 假设已经训练好了一个模型,保存模型
model_state = model.state_dict()
torch.save(model_state, 'model.pth')
```
加载模型时,可以使用 `torch.load()` 函数将模型状态字典从磁盘中加载出来,并使用 `load_state_dict()` 方法将模型参数加载到模型中。以下是加载模型的示例代码:
```python
import torch
# 加载模型
model = Model() # 这里的 Model 是你定义的模型类
model_state = torch.load('model.pth')
model.load_state_dict(model_state)
```
在加载模型时,需要确保模型类定义中的参数和保存的模型状态字典中的参数名称和顺序一致。如果有不一致的地方,可以在加载模型时使用 `strict=False` 参数来禁用严格模式,这样可以忽略一些不一致的参数。
```python
import torch
# 加载模型(禁用严格模式)
model = Model() # 这里的 Model 是你定义的模型类
model_state = torch.load('model.pth')
model.load_state_dict(model_state, strict=False)
```
使用加载好的模型进行预测时,只需要将数据传入模型即可。以下是使用模型进行预测的示例代码:
```python
import torch
# 加载模型
model = Model() # 这里的 Model 是你定义的模型类
model_state = torch.load('model.pth')
model.load_state_dict(model_state)
# 使用模型进行预测
input_data = torch.randn(1, 3, 224, 224) # 假设输入数据为 1 张 3 通道的 224x224 图片
output = model(input_data)
```