给出一个pytorch的rnn训练
时间: 2023-10-03 08:09:28 浏览: 45
以下是一个基本的PyTorch RNN训练示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# 定义RNN模型
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined)
output = self.i2o(combined)
output = self.softmax(output)
return output, hidden
def init_hidden(self):
return Variable(torch.zeros(1, self.hidden_size))
# 定义训练数据
input_sequence = ['hello', 'world', 'how', 'are', 'you']
output_sequence = ['h', 'w', 'h', 'a', 'y']
# 定义超参数
input_size = 5
hidden_size = 10
output_size = 1
learning_rate = 0.1
epochs = 1000
# 初始化模型和优化器
rnn = RNN(input_size, hidden_size, output_size)
optimizer = optim.SGD(rnn.parameters(), lr=learning_rate)
loss_function = nn.NLLLoss()
# 开始训练
for epoch in range(epochs):
hidden = rnn.init_hidden()
rnn.zero_grad()
loss = 0
for i in range(len(input_sequence)):
input_tensor = Variable(torch.FloatTensor([input_sequence[i]]))
target_tensor = Variable(torch.LongTensor([ord(output_sequence[i]) - 97]))
output, hidden = rnn(input_tensor, hidden)
loss += loss_function(output, target_tensor)
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print('Epoch: {}, Loss: {}'.format(epoch, loss.item()))
# 测试模型
test_sequence = ['hi', 'there', 'how', 'is', 'it']
for i in range(len(test_sequence)):
input_tensor = Variable(torch.FloatTensor([test_sequence[i]]))
hidden = rnn.init_hidden()
output, hidden = rnn(input_tensor, hidden)
print('{} -> {}'.format(test_sequence[i], chr(torch.max(output.data, 1)[1][0] + 97)))
```
该代码示例定义了一个简单的RNN模型,用于预测给定单词的第一个字母。在训练过程中,我们使用随机梯度下降优化器和负对数似然损失函数来训练模型。我们使用输入序列“hello”,“world”,“how”,“are”,“you”和输出序列“h”,“w”,“h”,“a”,“y”来训练模型。在训练结束后,我们使用模型来预测测试序列“hi”,“there”,“how”,“is”,“it”的第一个字母。