LSTM的torch.nn实现
时间: 2023-09-14 20:12:16 浏览: 103
Pythorch中torch.nn.LSTM()参数详解
5星 · 资源好评率100%
LSTM(Long Short-Term Memory)是一种循环神经网络(RNN)的变体,用于处理序列数据和处理长期依赖关系。在PyTorch中,LSTM可以通过torch.nn模块进行实现。
下面是一个使用torch.nn.LSTM实现LSTM模型的简单示例:
```python
import torch
import torch.nn as nn
# 定义LSTM模型
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :])
return out
# 设置超参数
input_size = 10
hidden_size = 32
num_layers = 2
num_classes = 2
sequence_length = 20
batch_size = 16
# 随机生成输入数据
x = torch.randn(batch_size, sequence_length, input_size).to(device)
# 创建LSTM模型实例
model = LSTMModel(input_size, hidden_size, num_layers, num_classes).to(device)
# 运行前向传播
outputs = model(x)
print(outputs.shape) # 输出形状为(batch_size, num_classes)
```
在这个示例中,我们定义了一个名为LSTMModel的自定义模型类,继承自nn.Module。模型的核心部分是一个nn.LSTM层和一个全连接层nn.Linear。在forward方法中,我们首先初始化LSTM的初始隐藏状态h0和内部状态c0,然后将输入数据x传入LSTM层,最后将LSTM最后一个时间步的输出经过全连接层得到最终预测结果。
你可以根据自己的需求调整模型的输入大小、隐藏状态大小、层数和输出类别数,并通过传入适当的输入数据运行前向传播。
阅读全文