lstm代码 pytorch
时间: 2024-09-24 11:09:10 浏览: 47
PyTorch案例:使用LSTM进行文本分类.zip
在PyTorch中,LSTM层的实现也非常直观,以下是创建一个简单的LSTM模型的例子:
```python
import torch
import torch.nn as nn
class SimpleLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(SimpleLSTM, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) # batch_first=True表示输入和输出的第一个维度是批量维度
self.fc = nn.Linear(hidden_size, output_size) # 全连接层用于输出
def forward(self, inputs):
h0 = torch.zeros(self.num_layers, inputs.size(0), self.hidden_size) # 初始化隐藏状态
c0 = torch.zeros(self.num_layers, inputs.size(0), self.hidden_size) # 初始化细胞状态
out, (hn, cn) = self.lstm(inputs, (h0, c0)) # 运行LSTM并获取最后时刻的隐藏状态和细胞状态
out = self.fc(out[:, -1, :]) # 取出最后时刻的输出作为整个序列的代表
return out
# 使用示例
input_size = 10 # 假设每个时间步有10个特征
hidden_size = 64
num_layers = 2
output_size = 1 # 输出一个值
model = SimpleLSTM(input_size, hidden_size, num_layers, output_size)
inputs = torch.randn((batch_size, seq_length, input_size)) # 假设batch_size是实际批次大小,seq_length是序列长度
outputs = model(inputs)
```
阅读全文