pytorch lstm如何保存训练好的模型
时间: 2023-02-28 12:06:23 浏览: 262
在 PyTorch 中,保存训练好的 LSTM 模型需要以下步骤:
1. 实例化一个 torch.nn.Module 子类,用于定义模型的结构,并将模型的参数初始化。
2. 使用数据训练模型。
3. 使用 torch.save 函数将模型保存到磁盘上。
代码示例:
```python
import torch
import torch.nn as nn
# 定义 LSTM 模型
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
def forward(self, x):
out, _ = self.lstm(x)
return out
# 初始化模型
model = LSTM(input_size=3, hidden_size=5, num_layers=2)
# 使用数据训练模型
# ...
# 保存模型
torch.save(model.state_dict(), "lstm_model.pth")
```
在需要加载模型时,可以使用 torch.load 函数:
```python
# 加载模型
model = LSTM(input_size=3, hidden_size=5, num_layers=2)
model.load_state_dict(torch.load("lstm_model.pth"))
```
阅读全文