pytorch框架,使用LSTM神经网络怎么训练一条曲线并保存训练模型,然后用训练的模型根据输入的前五个点预测出一条完整的曲线,用代码示例
时间: 2024-03-07 13:48:36 浏览: 85
好的,以下是一个使用 PyTorch 框架和 LSTM 网络训练一条曲线,并保存训练模型的代码示例:
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 设置随机种子,保证每次运行结果一致
torch.manual_seed(1)
# 创建数据集,生成一条 sin 曲线
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
# 将数据集转化为 PyTorch 张量,并添加一个维度
x = torch.from_numpy(x_np[:, np.newaxis])
y = torch.from_numpy(y_np[:, np.newaxis])
# 定义 LSTM 网络
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
self.linear = nn.Linear(hidden_size, output_size)
def forward(self, input):
h0 = torch.zeros(self.num_layers, input.size(1), self.hidden_size)
c0 = torch.zeros(self.num_layers, input.size(1), self.hidden_size)
out, _ = self.lstm(input, (h0, c0))
out = self.linear(out[-1])
return out
# 设置超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 2
learning_rate = 0.01
num_epochs = 1000
# 创建 LSTM 模型
lstm = LSTM(input_size, hidden_size, output_size, num_layers)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
# 开始训练
for epoch in range(num_epochs):
inputs = x
labels = y
# 前向传播
outputs = lstm(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 100 == 0:
print("Epoch [{}/{}], Loss: {:.4f}".format(epoch+1, num_epochs, loss.item()))
# 保存模型
torch.save(lstm.state_dict(), 'lstm_model.ckpt')
# 使用训练好的模型预测一条完整的曲线
lstm.eval()
with torch.no_grad():
future = 100
pred = x[-1]
preds = []
for i in range(future):
pred = lstm(pred.unsqueeze(0))
preds.append(pred.squeeze().item())
pred = pred.detach()
preds = np.array(preds)
# 绘制预测结果
x_np = x_np.reshape(-1)
y_np = y_np.reshape(-1)
plt.plot(steps, y_np, 'r', label='ground truth')
plt.plot(steps, x_np, 'b', label='input')
plt.plot(np.arange(steps[-1]+1, steps[-1]+future+1), preds, 'g', label='prediction')
plt.legend(loc='best')
plt.show()
```
以上代码将生成一条 sin 曲线,并使用 LSTM 网络进行训练和保存模型,最后使用训练好的模型预测一条完整的曲线并绘制出来。
阅读全文