LSTM PYTHON代码
时间: 2023-11-09 11:54:07 浏览: 131
LSTM(长短期记忆)是一种用于处理序列数据的人工神经网络。相较于传统的循环神经网络(RNN),LSTM通过引入门控机制,解决了梯度消失和梯度爆炸的问题,使得网络能够有效地处理长序列数据。
下面是一个简单的LSTM的Python代码示例:
```python
import torch
import torch.nn as nn
# 定义LSTM模型
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :])
return out
# 定义输入和参数
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 1
batch_size = 16
seq_length = 5
# 随机生成输入数据
x = torch.randn(batch_size, seq_length, input_size)
# 创建LSTM模型实例
lstm_model = LSTMModel(input_size, hidden_size, num_layers, output_size)
# 前向传播
output = lstm_model(x)
# 打印输出
print(output.shape)
```
阅读全文