GRU模型推理python
时间: 2023-08-13 16:22:46 浏览: 178
在Python中,可以使用深度学习库如TensorFlow、PyTorch等来实现GRU模型的推理。以下是一个使用PyTorch实现GRU模型推理的示例代码:
```python
import torch
import torch.nn as nn
# 定义GRU模型
class GRUModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRUModel, self).__init__()
self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, _ = self.gru(x)
out = self.fc(out[:, -1, :])
return out
# 加载训练好的模型参数
model = GRUModel(input_size=10, hidden_size=20, output_size=2)
model.load_state_dict(torch.load('gru_model.pth'))
# 输入新数据
x = torch.randn(1, 5, 10) # 输入数据形状为[batch_size, seq_len, input_size]
# 执行推理过程
model.eval()
with torch.no_grad():
output = model(x)
_, predicted = torch.max(output.data, 1)
# 后处理输出结果
print(predicted.item())
```
在上面的示例中,我们首先定义了一个GRU模型,并加载了训练好的模型参数。然后,我们生成了一个随机的输入数据x,并将其输入到模型中进行推理。推理过程中,我们使用了PyTorch中的no_grad()上下文管理器来禁用梯度计算,以提高推理速度。最后,我们将模型的输出结果后处理成了一个标量值,并打印出来。
阅读全文