pytorch GRU预测模型R² 代码
时间: 2023-09-23 15:13:30 浏览: 164
基于Pytorch实现GRU模型
以下是一个使用 PyTorch 中 GRU 模型进行预测并计算 R² 的示例代码:
```python
import torch
import torch.nn as nn
from sklearn.metrics import r2_score
# 定义 GRU 模型
class GRUModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRUModel, self).__init__()
self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
_, h_n = self.gru(x)
out = self.fc(h_n.squeeze(0))
return out
# 定义训练数据和标签
train_data = torch.randn(100, 10, 1) # 100 个序列,每个序列有 10 个特征
train_label = torch.randn(100, 1)
# 定义模型、损失函数和优化器
model = GRUModel(1, 32, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(100):
optimizer.zero_grad()
output = model(train_data)
loss = criterion(output, train_label)
loss.backward()
optimizer.step()
# 测试模型并计算 R²
test_data = torch.randn(50, 10, 1)
test_label = torch.randn(50, 1)
model.eval()
with torch.no_grad():
pred = model(test_data)
r2 = r2_score(test_label.numpy(), pred.numpy())
print("R²:", r2)
```
在上面的代码中,我们使用 PyTorch 中的 `nn.GRU` 实现了一个 GRU 模型,通过训练数据和标签进行训练,并使用 `sklearn.metrics.r2_score` 计算预测结果和真实标签之间的 R² 值。
阅读全文