BiLSTM的Python代码
时间: 2023-11-05 16:19:50 浏览: 104
下面是一个简单的使用BiLSTM模型的Python代码示例:
```python
import torch
import torch.nn as nn
class BiLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(BiLSTM, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device) # 初始化隐藏状态h0
c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device) # 初始化细胞状态c0
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :]) # 获取最后一个时间步的输出
return out
# 设置超参数
input_size = 28 # 输入特征维度
hidden_size = 128 # LSTM隐藏层大小
num_layers = 2 # LSTM层数
num_classes = 10 # 分类类别数
# 构建模型
model = BiLSTM(input_size, hidden_size, num_layers, num_classes)
# 预测样例
sample_input = torch.randn(1, 28, 28) # 输入样例
outputs = model(sample_input)
print(outputs)
```
这段代码定义了一个BiLSTM模型,用于进行文本分类等任务。通过`forward`方法定义了前向传播过程,使用`nn.LSTM`构建双向LSTM层,并使用线性层`nn.Linear`将LSTM输出映射到类别数目上。最后,我们可以通过输入一个样例数据来进行预测。请注意,这只是一个简单的示例,并且可能需要根据实际任务进行调整和修改。
阅读全文