gru时间序列预测pytorch代码
时间: 2023-10-23 21:09:31 浏览: 117
以下是使用PyTorch实现GRU进行时间序列预测的代码示例:
```python
import torch
import torch.nn as nn
class GRU(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRU, self).__init__()
self.hidden_size = hidden_size
self.gru = nn.GRU(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x, h0):
out, hn = self.gru(x, h0)
out = self.fc(out[:, -1, :])
return out
# 定义模型参数
input_size = 1
hidden_size = 32
output_size = 1
# 实例化模型
model = GRU(input_size, hidden_size, output_size)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(100):
optimizer.zero_grad()
loss = 0
h0 = torch.zeros(1, 1, hidden_size)
for i in range(len(train_data) - 1):
x = train_data[i:i+1].unsqueeze(2).float()
y_true = train_data[i+1:i+2].float()
y_pred = model(x, h0)
loss += criterion(y_pred, y_true)
h0 = h0.detach()
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 100, loss.item()))
# 预测未来的值
future = 30
preds = train_data[-1:].unsqueeze(2).float()
h0 = torch.zeros(1, 1, hidden_size)
for i in range(future):
with torch.no_grad():
pred = model(preds, h0)
preds = torch.cat([preds, pred.unsqueeze(2)], dim=0)
h0 = h0.detach()
# 可视化预测结果
import matplotlib.pyplot as plt
plt.plot(train_data.numpy())
plt.plot(range(len(train_data)-1, len(train_data)+future-1), preds.detach().numpy())
plt.show()
```
其中,`train_data`是训练数据,可以是一个PyTorch张量。在训练过程中,我们使用均方误差损失函数和Adam优化器。在预测时,我们使用训练好的模型来预测未来的值,并将预测结果可视化。
阅读全文