GRU pytorch 代码
时间: 2024-09-24 13:01:21 浏览: 39
GRU (Gated Recurrent Unit)是一种门控循环单元,常用于处理序列数据,比如自然语言处理任务。在PyTorch中,你可以使用nn.GRU模块来创建一个GRU层。以下是一个简单的例子,展示如何创建一个单向GRU层:
```python
import torch
import torch.nn as nn
class GRUModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1):
super(GRUModel, self).__init__()
# 初始化GRU层
self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
def forward(self, inputs):
# 对输入应用GRU
outputs, _ = self.gru(inputs) # _ 表示隐藏状态,如果你需要也可以保存
return outputs
# 创建模型实例
input_size = 100
hidden_size = 50
model = GRUModel(input_size, hidden_size)
# 假设inputs是形状为(batch_size, seq_len, input_size)的张量
output = model(inputs)
```
在这个例子中,`inputs`是一个包含时间序列数据的批次张量,`forward()`函数会返回每个时间步后的隐藏状态。
阅读全文