GRU code in PyTorch
Sure! Here's an example of how to implement a GRU (Gated Recurrent Unit) in PyTorch:
```python
import torch
import torch.nn as nn


class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True means the input has shape (batch, seq_len, input_size)
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initial hidden state: (num_layers, batch, hidden_size), zeros on the input's device
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        # out: (batch, seq_len, hidden_size) -- the last layer's hidden state at every time step
        out, _ = self.gru(x, h0)
        # Feed only the last time step's hidden state to the output layer
        out = self.fc(out[:, -1, :])
        return out


# Example usage:
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 5
batch_size = 3
sequence_length = 4

# Create dummy input
x = torch.randn(batch_size, sequence_length, input_size)

# Create an instance of the GRU model
model = GRUModel(input_size, hidden_size, num_layers, output_size)

# Forward pass
output = model(x)
print(output.shape)  # torch.Size([3, 5])
```
In this example, we define a `GRUModel` class that inherits from `nn.Module`. It takes the input size, hidden size, number of layers, and output size as parameters. Inside the class, we define the GRU layer using `nn.GRU` with `batch_first=True`, so inputs are expected in the shape `(batch_size, sequence_length, input_size)`, followed by a fully connected layer (`nn.Linear`) that maps the hidden state to the final output.
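To make the tensor shapes concrete, here is a small standalone check of what `nn.GRU` itself returns (the sizes match the example above; the variable names are just for illustration):

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(3, 4, 10)  # (batch, seq_len, input_size)

out, h_n = gru(x)
print(out.shape)  # torch.Size([3, 4, 20]) -- last layer's hidden state at every time step
print(h_n.shape)  # torch.Size([2, 3, 20]) -- final hidden state of each layer
```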
In the `forward` method, we initialize the hidden state (`h0`) with zeros and pass it to the GRU layer along with the input tensor `x`. Only the hidden state of the last time step (`out[:, -1, :]`) is then passed through the fully connected layer to produce the final output.
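Note that `nn.GRU` defaults the initial hidden state to zeros when none is given, so the explicit `h0` above is mainly for clarity. A quick check, assuming the `model` and tensors from the example above:

```python
# Both calls use a zero initial hidden state, so the outputs are identical
out_explicit, _ = model.gru(x, torch.zeros(num_layers, batch_size, hidden_size))
out_default, _ = model.gru(x)
print(torch.allclose(out_explicit, out_default))  # True
```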
Finally, we create an instance of the `GRUModel` class and perform a forward pass on a dummy input tensor; the printed shape is `torch.Size([3, 5])`, i.e. `(batch_size, output_size)`.
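If you want to train this model, a minimal sketch of a training loop might look like the following. The dummy targets, loss function, and learning rate are placeholders you would replace for your own task:

```python
import torch
import torch.nn as nn

model = GRUModel(input_size=10, hidden_size=20, num_layers=2, output_size=5)
criterion = nn.CrossEntropyLoss()  # assumes a 5-class classification task
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(3, 4, 10)          # dummy batch: (batch, seq_len, input_size)
y = torch.randint(0, 5, (3,))      # dummy class labels, one per sequence

for epoch in range(10):
    optimizer.zero_grad()
    logits = model(x)              # (batch, output_size)
    loss = criterion(logits, y)
    loss.backward()                # backpropagate through the unrolled sequence
    optimizer.step()
```

In practice you would iterate over a `DataLoader` of real batches instead of a single dummy tensor, but the structure of the loop stays the same.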