写一个torch.nn.LSTM()代码
时间: 2024-02-26 07:58:34 浏览: 67
下面是一个简单的使用 torch.nn.LSTM() 创建 LSTM 层的代码示例:
```python
import torch
import torch.nn as nn
# 定义输入向量维度和隐藏状态维度
input_size = 10
hidden_size = 20
# 定义 LSTM 层
lstm_layer = nn.LSTM(input_size, hidden_size)
# 构造输入数据
input_seq = torch.randn(5, 3, 10) # 输入序列长度为 5,batch 大小为 3,输入向量维度为 10
# 初始化隐藏状态
h0 = torch.randn(1, 3, 20) # num_layers * num_directions, batch_size, hidden_size
c0 = torch.randn(1, 3, 20)
# 将输入序列和隐藏状态传入 LSTM 层
output, (hn, cn) = lstm_layer(input_seq, (h0, c0))
# 输出当前的隐藏状态和输出
print(output.shape) # torch.Size([5, 3, 20])
print(hn.shape) # torch.Size([1, 3, 20])
print(cn.shape) # torch.Size([1, 3, 20])
```
在上面的代码中,我们首先定义了输入向量的维度和隐藏状态的维度,然后使用 nn.LSTM() 创建了一个 LSTM 层。接着,我们构造了一个输入序列,其长度为 5,batch 大小为 3,输入向量维度为 10。然后,我们初始化了隐藏状态,将输入序列和隐藏状态传入 LSTM 层,最后输出当前的隐藏状态和输出。
阅读全文