torch.nn.LSTM()
时间: 2023-09-10 11:07:32 浏览: 31
torch.nn.LSTM() 是 PyTorch 深度学习库中的一个模块,用于实现长短时记忆网络(Long Short-Term Memory,LSTM)的前向传播过程。LSTM 是一种常用的循环神经网络,通常用于处理序列数据,如文本、语音等。
LSTM 能够有效地解决循环神经网络中的梯度消失和梯度爆炸问题,通过引入门控机制,控制信息的流动和保留,实现了长期依赖关系的建模。
torch.nn.LSTM() 中需要传入的参数包括输入的特征维度、隐藏层的维度、LSTM 的层数、是否使用双向 LSTM 等。在调用 forward() 方法时,需要传入当前时间步的输入数据以及上一个时间步的隐藏状态和记忆单元,LSTM 将返回当前时间步的输出、隐藏状态和记忆单元。
示例代码:
```python
import torch.nn as nn
# 输入特征维度为 10,隐藏层维度为 20,LSTM 层数为 2
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
# 输入数据 shape 为 (seq_len, batch_size, input_size)
input_data = torch.randn(5, 3, 10)
# 初始化隐藏状态和记忆单元
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
# 前向传播
output, (hn, cn) = lstm(input_data, (h0, c0))
print(output.shape) # 输出 shape 为 (seq_len, batch_size, hidden_size)
print(hn.shape) # 隐藏状态 shape 为 (num_layers * num_directions, batch_size, hidden_size)
print(cn.shape) # 记忆单元 shape 为 (num_layers * num_directions, batch_size, hidden_size)
```