写一段代码用pytorch训练RNN模型
时间: 2024-01-26 13:01:47 浏览: 28
以下是使用 PyTorch 训练 RNN 模型的一个简单示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义 RNN 模型
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNNModel, self).__init__()
self.hidden_size = hidden_size
self.rnn = nn.RNN(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x, h0):
out, hn = self.rnn(x, h0)
out = self.fc(out[:, -1, :])
return out, hn
# 定义训练函数
def train(model, optimizer, criterion, inputs, targets, epochs):
for epoch in range(epochs):
optimizer.zero_grad()
h0 = torch.zeros(1, inputs.size(0), model.hidden_size)
outputs, hn = model(inputs, h0)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, loss.item()))
# 定义训练数据和目标数据
inputs = torch.randn(10, 1, 5)
targets = torch.randn(10, 1)
# 定义模型、优化器和损失函数
model = RNNModel(5, 10, 1)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
# 开始训练
train(model, optimizer, criterion, inputs, targets, epochs=100)
```
在上面的代码中,我们首先定义了一个简单的 RNN 模型,并使用它来训练随机生成的输入和目标数据。我们使用均方误差损失函数来评估模型的性能,并使用 Adam 优化器来更新模型参数。在训练过程中,我们使用 PyTorch 提供的自动求导功能来计算损失函数关于模型参数的梯度,并使用优化器来更新模型参数。