pytorch实现RNN
时间: 2023-08-01 12:10:17 浏览: 137
要使用PyTorch实现RNN,你需要导入`torch`和`torch.nn`模块。以下是一个简单的RNN实现的示例代码:
```python
import torch
import torch.nn as nn
# 定义RNN模型
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined)
output = self.i2o(combined)
output = self.softmax(output)
return output, hidden
def init_hidden(self):
return torch.zeros(1, self.hidden_size)
# 示例数据
input_size = 10
hidden_size = 20
output_size = 5
# 创建模型实例
rnn = RNN(input_size, hidden_size, output_size)
# 输入数据
input = torch.randn(1, input_size)
hidden = torch.randn(1, hidden_size)
# 前向传播
output, next_hidden = rnn(input, hidden)
print(output)
```
这段代码定义了一个简单的RNN模型。你可以根据实际需求调整模型的输入维度、隐藏层维度和输出维度。在示例中,我们使用了`nn.Linear`作为输入到隐藏层和输出层的全连接层,并使用`nn.LogSoftmax`作为输出的激活函数。
阅读全文