BILSTM代码
时间: 2023-07-02 15:15:34 浏览: 126
以下是使用PyTorch构建的BILSTM模型的示例代码:
```
import torch
import torch.nn as nn
class BiLSTM(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(BiLSTM, self).__init__()
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_dim*2, output_dim)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), self.hidden_dim).to(device=x.device)
c0 = torch.zeros(2, x.size(0), self.hidden_dim).to(device=x.device)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :])
return out
```
以上代码定义了一个名为“BiLSTM”的类,该类包含一个双向LSTM层和一个全连接层。在初始化函数中,我们指定了输入维度、隐藏维度和输出维度。在前向函数中,我们首先将输入的张量传递给LSTM层,然后提取LSTM输出的最后一个时间步,并将其传递给全连接层进行分类预测。该模型可以在文本分类、情感分析等任务中使用。
阅读全文