rnn代码pytorch
时间: 2023-10-22 20:10:47 浏览: 105
好的,以下是一个简单的RNN代码示例,使用PyTorch实现:
```python
import torch
import torch.nn as nn
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined)
output = self.i2o(combined)
output = self.softmax(output)
return output, hidden
def initHidden(self):
return torch.zeros(1, self.hidden_size)
# 示例用法
n_input = 10
n_hidden = 20
n_output = 5
rnn = RNN(n_input, n_hidden, n_output)
input_tensor = torch.randn(1, n_input)
hidden_tensor = rnn.initHidden()
output_tensor, next_hidden_tensor = rnn(input_tensor, hidden_tensor)
```
这个RNN模型接受一个输入张量和一个隐藏状态张量,并返回一个输出张量和下一个隐藏状态张量。在这个示例中,我们使用`LogSoftmax`作为输出层的激活函数,因此输出张量的每个元素都是对应类别的对数概率。
阅读全文