LSTM模型训练代码 pytorch
时间: 2023-08-08 15:10:14 浏览: 119
element-lstm-forecast-m笔记
好的,这里是一个简单的LSTM模型训练代码示例,使用PyTorch实现。
```
import torch
import torch.nn as nn
import torch.optim as optim
# 定义 LSTM 模型
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, input):
output, (hidden, cell) = self.lstm(input)
output = self.fc(hidden.view(-1, self.hidden_size))
return output
# 准备数据
input_size = 10
hidden_size = 20
output_size = 1
seq_len = 5
batch_size = 3
input = torch.randn(seq_len, batch_size, input_size)
target = torch.randn(batch_size, output_size)
# 创建模型实例
model = LSTMModel(input_size, hidden_size, output_size)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
optimizer.zero_grad()
output = model(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# 保存模型
torch.save(model.state_dict(), 'lstm_model.pth')
```
在这个示例中,我们首先定义了一个LSTM模型,然后使用随机数据准备了训练数据集和测试数据集。接下来,我们定义了损失函数和优化器,并使用训练数据集训练了模型。最后,我们保存了训练好的模型参数。
阅读全文