借助pytorch实现GRU代码
时间: 2023-12-19 20:05:45 浏览: 126
好的,以下是一个使用PyTorch实现GRU的代码示例:
```python
import torch
import torch.nn as nn
class GRU(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(GRU, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x, h):
out, h = self.gru(x, h)
out = self.fc(out[:, -1, :])
return out, h
def init_hidden(self, batch_size):
return torch.zeros(self.num_layers, batch_size, self.hidden_size)
# 使用示例
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 1
batch_size = 32
seq_len = 5
model = GRU(input_size, hidden_size, num_layers, output_size)
x = torch.randn(batch_size, seq_len, input_size)
h = model.init_hidden(batch_size)
out, h = model(x, h)
print(out.shape) # torch.Size([32, 1])
```
这是一个简单的GRU模型,包含一个GRU层和一个全连接层。在`forward`函数中,我们首先通过GRU层获取输出和最后一个时刻的隐状态,然后将最后一个时刻的输出通过全连接层得到最终的输出。`init_hidden`函数用于初始化隐状态。
阅读全文