nn.LSTM使用的例子
时间: 2023-11-10 13:50:59 浏览: 42
下面是一个使用PyTorch中的nn.LSTM模块训练一个文本分类模型的例子:
```python
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(LSTMClassifier, self).__init__()
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(input_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, input_seq):
# 初始化LSTM的隐藏状态
h0 = torch.zeros(1, input_seq.size(1), self.hidden_dim)
c0 = torch.zeros(1, input_seq.size(1), self.hidden_dim)
# 将输入序列传入LSTM模块
output, (hidden_state, cell_state) = self.lstm(input_seq, (h0, c0))
# 取最后一个时间步的输出作为分类器的输入
output = output[-1, :, :]
# 将LSTM的输出传入全连接层
output = self.fc(output)
return output
```
在这个例子中,我们定义了一个LSTMClassifier类,它包含一个nn.LSTM模块和一个nn.Linear模块,用于将LSTM的输出转换为分类结果。在forward方法中,我们首先使用torch.zeros方法初始化LSTM的隐藏状态和记忆状态,然后将输入序列传入LSTM模块。由于这个例子中的任务是文本分类,因此我们只需要取最后一个时间步的输出作为分类器的输入。最后,我们将LSTM的输出传入全连接层,得到最终的分类结果。