gru网络python实例
时间: 2023-09-09 15:12:04 浏览: 108
下面是一个简单的 GRU 网络实例,使用 PyTorch 实现:
```python
import torch
import torch.nn as nn
class GRUNet(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(GRUNet, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x, h):
out, h = self.gru(x, h)
out = self.fc(out[:, -1, :])
return out, h
def init_hidden(self, batch_size):
return torch.zeros(self.num_layers, batch_size, self.hidden_size)
# 测试
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 5
batch_size = 3
seq_len = 4
x = torch.randn(batch_size, seq_len, input_size)
model = GRUNet(input_size, hidden_size, num_layers, output_size)
h = model.init_hidden(batch_size)
out, h = model(x, h)
print(out.shape) # (3, 5)
```
在上面的实例中,我们实现了一个两层的 GRU 网络,输入的数据维度是 `(batch_size, seq_len, input_size)`,输出的数据维度是 `(batch_size, output_size)`。在 `forward` 方法中,我们使用了 PyTorch 中内置的 GRU 层和线性层,`init_hidden` 方法用于初始化隐藏状态。在测试中,我们生成了一组随机的输入数据,并将其输入到 GRU 网络中,得到了相应的输出。
阅读全文