使用 PyTorch 构建 GRU 模型
时间: 2023-12-30 18:24:46 浏览: 74
以下是使用PyTorch建立GRU模型的示例代码:
```python
import torch
import torch.nn as nn
class GRU(nn.Module):
    """Thin wrapper around a batch-first, multi-layer ``nn.GRU``.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Number of features in the GRU hidden state.
        num_layers: Number of stacked GRU layers.

    ``forward`` returns the top layer's output for every time step; the
    final hidden state produced by ``nn.GRU`` is discarded.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        """Run the GRU over ``x`` starting from a zero hidden state.

        Args:
            x: Batch-first input tensor of shape
               ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)`` holding the
            top layer's output at every time step.
        """
        # Allocate the zero initial state directly with x's dtype and device.
        # The previous `torch.zeros(...).to(x.device)` always produced float32
        # (breaking half/double models) and allocated on the CPU first.
        h0 = torch.zeros(
            self.num_layers,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )
        out, _ = self.gru(x, h0)
        return out
# Build the GRU model: 10 input features, 20 hidden units, 2 stacked layers.
input_size, hidden_size, num_layers = 10, 20, 2
gru_model = GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

# Random batch-first input of shape (batch_size, seq_length, input_size).
batch_size, seq_length = 3, 5
input_data = torch.randn(batch_size, seq_length, input_size)

# Forward pass; expected shape is (batch_size, seq_length, hidden_size).
output = gru_model(input_data)
print(output.shape)  # Output: torch.Size([3, 5, 20])
```
阅读全文