paddle框架下,paddle.save保存模型,如何加载已保存的模型,请代码展示
时间: 2024-04-14 07:29:46 浏览: 183
保存加载对象的实现
在 PaddlePaddle 框架中,可以使用 `paddle.load` 函数加载已保存的模型。以下是一个简单的示例代码,展示了如何保存和加载模型:
```python
import paddle
# 定义模型
model = paddle.nn.Linear(10, 2)
# 保存模型
paddle.save(model.state_dict(), 'model.pdparams')
# 加载模型
loaded_model_state_dict = paddle.load('model.pdparams')
loaded_model = paddle.nn.Linear(10, 2)
loaded_model.set_state_dict(loaded_model_state_dict)
# 使用加载的模型进行推理
input_data = paddle.randn([1, 10])
output = loaded_model(input_data)
print(output)
```
在上述示例中,我们首先定义了一个简单的线性模型 `model`,然后使用 `paddle.save` 函数保存模型的状态字典(state_dict)到文件 `model.pdparams` 中。
接着,我们使用 `paddle.load` 函数加载已保存的模型参数,返回的结果是一个字典类型的对象。我们再次定义了一个与原模型相同结构的新模型 `loaded_model`,并使用 `set_state_dict` 方法将加载的参数设置给新模型。
最后,我们使用加载的模型进行推理,输入一个随机生成的张量 `input_data`,并打印输出结果 `output`。
需要注意的是,加载模型时需要确保加载的模型结构与原模型相匹配,否则可能会导致错误。此外,还可以通过指定 `map_location` 参数来选择加载模型的设备(如 CPU 或 GPU)。
更多关于模型保存和加载的信息,可以参考 PaddlePaddle 官方文档。
阅读全文