pytorch实现gru-attention时间序列预测代码
时间: 2023-08-27 09:07:06 浏览: 239
下面是一个使用 PyTorch 实现 GRU-Attention 时间序列预测的示代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# 数据集生成函数
def generate_data(n_samples, seq_length):
X = np.zeros((n_samples, seq_length))
Y = np.zeros((n_samples, seq_length))
for i in range(n_samples):
start = np.random.randint(0, 10)
X[i, 0] = start
for j in range(1, seq_length):
X[i, j] = X[i, j-1] + np.random.normal(0, 0.1)
Y[i, :] = X[i, :] + np.random.normal(0, 0.1, size=(seq_length,))
return X, Y
# 超参数
n_samples = 1000
seq_length = 20
input_dim = 1
hidden_dim = 32
output_dim = 1
learning_rate = 0.01
num_epochs = 100
# 数据集准备
X, Y = generate_data(n_samples, seq_length)
X = torch.from_numpy(X).float()
Y = torch.from_numpy(Y).float()
# 模型定义
class GRUAttention(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(GRUAttention, self).__init__()
self.hidden_dim = hidden_dim
self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
self.w1 = nn.Linear(hidden_dim, hidden_dim)
self.w2 = nn.Linear(hidden_dim, 1)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
# GRU 层
out, hidden = self.gru(x)
# 注意力层
attn_weights = torch.tanh(self.w1(out))
attn_weights = self.w2(attn_weights)
attn_weights = torch.softmax(attn_weights, dim=1)
attn_applied = torch.bmm(attn_weights.transpose(1,2), out)
# 全连接层
out = self.fc(attn_applied)
return out
# 模型实例化
model = GRUAttention(input_dim, hidden_dim, output_dim)
# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 模型训练
for epoch in range(num_epochs):
# 前向传播
y_pred = model(X)
# 计算损失
loss = criterion(y_pred, Y)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印损失
if (epoch+1) % 10 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# 模型预测
model.eval()
with torch.no_grad():
y_pred = model(X)
# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(Y[0].numpy(), label='Ground Truth')
plt.plot(y_pred[0].numpy(), label='Predicted')
plt.legend()
plt.show()
```
这个示例代码使用了一个简单的数据集,包含 1000 个序列,每个序列长度为 20。模型使用了一个 GRU 层和一个注意力层,以及一个全连接层进行预测。最后,使用 PyTorch 的优化器进行训练,并绘制出预测结果。