pytorch搭建lstm
时间: 2023-10-29 16:06:33 浏览: 112
project2_pytorch实现lstm_
5星 · 资源好评率100%
下面是一个简单的例子,展示如何使用PyTorch搭建LSTM模型:
```python
import torch
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
super(LSTMModel, self).__init__()
# 隐藏层维度
self.hidden_dim = hidden_dim
# LSTM层数
self.layer_dim = layer_dim
# LSTM层
self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
# 输出层
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
# 初始化隐藏层状态
h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
# 初始化记忆单元状态
c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
# 前向传播 LSTM
out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
# 取最后一个时间步的输出
out = self.fc(out[:, -1, :])
return out
```
在这个例子中,我们定义了一个LSTMModel类,它继承自nn.Module。在构造函数中,我们定义了LSTM的输入维度,隐藏层维度、LSTM层数以及输出维度。接下来,我们定义了LSTM层和输出层。在forward函数中,我们首先初始化了隐藏层状态和记忆单元状态,然后调用LSTM层进行前向传播,最后取最后一个时间步的输出并通过输出层得到最终的输出。
阅读全文