pytorch的RNN代码示例
时间: 2024-07-21 12:01:14 浏览: 73
PyTorch是一个流行的深度学习框架,其中循环神经网络(Recurrent Neural Network, RNN)是一种用于处理序列数据的重要模型。下面是一个简单的LSTM(Long Short-Term Memory)网络的代码示例,它在PyTorch中用于文本生成任务:
```python
import torch
from torch import nn
# 定义超参数
input_size = 100 # 输入特征维度
hidden_size = 50 # 隐藏层单元数
num_layers = 2 # 双向RNN的层数
seq_length = 20 # 序列长度
output_size = 10 # 输出维数 (这里假设每个字符有一个10维的one-hot编码)
# 创建LSTM模型
class RNNModel(nn.Module):
def __init__(self):
super(RNNModel, self).__init__()
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_size * 2, output_size) # 注意这里是双向隐藏层大小
def forward(self, x):
# 使用LSTM处理输入
out, _ = self.rnn(x)
# 将最后一个时间步的上下文连接到全连接层
out = self.fc(out[:, -1, :])
return out
model = RNNModel()
# 假设我们有输入数据 x (形状(batch_size, seq_length, input_size))
# 你需要定义一个数据加载器和优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs): # 这里num_epochs是你希望训练的轮数
for inputs in your_data_loader:
optimizer.zero_grad() # 梯度清零
outputs = model(inputs) # 计算预测值
loss = criterion(outputs, targets) # 使用交叉熵损失函数
loss.backward() # 反向传播
optimizer.step() # 更新权重
阅读全文