class RnnModel(nn.Module)
时间: 2024-02-22 22:38:25 浏览: 48
浅析PyTorch中nn.Module的使用
5星 · 资源好评率100%
class RnnModel(nn.Module):
    """Single-layer, unidirectional ``nn.RNN`` with a linear head on the last step.

    Inputs use ``nn.RNN``'s default sequence-first layout (``batch_first=False``):
    ``(seq_len, batch, input_size)``. The model returns one prediction per
    sequence, shaped ``(batch, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Build the recurrent layer and the output projection.

        Args:
            input_size: Number of features per time step.
            hidden_size: Size of the RNN hidden state.
            output_size: Number of features produced by the linear head.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Run the RNN over ``input`` and project the final time step.

        Args:
            input: 3-D tensor of shape ``(seq_len, batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # nn.RNN is sequence-first by default: dim 0 is seq_len, dim 1 is batch.
        # Using dim 0 here would build a hidden state with the wrong batch size.
        batch_size = input.size(1)
        # Match the input's device/dtype so the model also works on GPU or
        # with non-default dtypes.
        hidden = self.init_hidden(batch_size, device=input.device, dtype=input.dtype)
        output, _ = self.rnn(input, hidden)
        # Keep only the last time step's output: (batch, hidden_size).
        last_step = output[-1]
        return self.fc(last_step)

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return an all-zero initial hidden state.

        Args:
            batch_size: Number of sequences in the batch.
            device: Device for the tensor; ``None`` uses torch's default device.
            dtype: dtype for the tensor; ``None`` uses torch's default dtype.

        Returns:
            Tensor of shape ``(1, batch_size, hidden_size)``. The leading 1 is
            ``num_layers * num_directions`` for this single-layer,
            unidirectional RNN.
        """
        return torch.zeros(1, batch_size, self.hidden_size, device=device, dtype=dtype)
阅读全文