pytorch 如何保存训练好的模型并调用
时间: 2023-09-10 18:10:40 浏览: 108
Pytorch——保存训练好的模型参数
5星 · 资源好评率100%
您可以使用 PyTorch 提供的 torch.save() 函数来保存训练好的模型。使用该函数时,需要指定要保存的模型、保存路径和文件名称。下面是一个保存模型的示例代码:
```
import torch
# 定义模型
model = MyModel()
# 训练模型...
# 保存模型
torch.save(model.state_dict(), 'model.pt')
```
在加载模型时,可以使用 torch.load() 函数来加载模型的参数,然后将其应用于新的模型实例中。下面是一个加载模型的示例代码:
```
import torch
# 定义模型结构
model = MyModel()
# 加载模型参数
model.load_state_dict(torch.load('model.pt'))
# 使用模型进行推理
output = model(input)
```
请注意,使用 torch.load() 加载模型时,需要设置 map_location 参数以指定设备类型(例如 CPU 或 GPU),否则可能会出现错误。例如,如果您的模型在 GPU 上训练并保存,则可以使用以下代码将其加载到 CPU 上:
```
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
```
阅读全文