torch.save(model.state_dict(), file + ".pth")
时间: 2024-09-25 12:10:33 浏览: 50
pytorch 状态字典:state_dict使用详解
5星 · 资源好评率100%
当你在PyTorch中训练完模型并想要保存它的状态以便后续使用时,可以使用`torch.save()`函数。这个函数用于将模型的状态字典(`.state_dict()`)保存到磁盘上指定的文件(这里是一个".pth"格式的二进制文件)。`model.state_dict()`返回的是模型内部权重、梯度等参数的集合,而`file + ".pth"`则是你要存储该状态字典的目标文件路径。
例如,假设你有一个名为`model`的神经网络模型,你可以这样做:
```python
model = YourModel() # 假设YourModel是你定义的类
optimizer = torch.optim.Adam(model.parameters()) # 初始化优化器
# 训练过程...
# 在训练结束后保存模型状态
torch.save(model.state_dict(), "model.pth")
```
如果你想恢复这个模型,在以后加载时只需要加载状态字典,并设置为某个已初始化好的模型实例:
```python
loaded_model = YourModel() # 创建新的模型实例
loaded_model.load_state_dict(torch.load("model.pth"))
```
阅读全文