Write some BiLSTM code
Time: 2023-12-01 10:43:17
Sure, here's an example of BiLSTM code using PyTorch:
```python
import torch
import torch.nn as nn


class BiLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BiLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        # batch_first defaults to False: input is (seq_len, batch, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # *2 for bidirectional

    def forward(self, x):
        # lstm_out: (seq_len, batch, 2 * hidden_dim)
        lstm_out, _ = self.lstm(x)
        # Take the last time step; dimension 0 is the time axis here
        out = self.fc(lstm_out[-1])  # (batch, output_dim)
        return out
```
In this code snippet, we define a `BiLSTM` class that inherits from `nn.Module`. The constructor takes the input dimension, hidden dimension, and output dimension as arguments and defines the network's two layers: an LSTM layer created with `bidirectional=True`, and a fully connected layer that maps the concatenated forward and backward features (`hidden_dim * 2`) to `output_dim`.
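To see why the fully connected layer takes `hidden_dim * 2` inputs, the sketch below inspects a standalone bidirectional `nn.LSTM`. The sizes (`input_dim=8`, `hidden_dim=16`, and so on) are arbitrary values chosen only for illustration. With `bidirectional=True`, PyTorch keeps a separate set of weights for the reverse direction (parameter names suffixed `_reverse`) and concatenates both directions' outputs along the feature axis.

```python
import torch
import torch.nn as nn

# Illustrative sizes (arbitrary, not from the original answer)
input_dim, hidden_dim = 8, 16
seq_len, batch_size = 5, 3

lstm = nn.LSTM(input_dim, hidden_dim, bidirectional=True)

# Each direction has its own input-to-hidden weights: 4 gates * hidden_dim rows
print(lstm.weight_ih_l0.shape)          # torch.Size([64, 8])
print(lstm.weight_ih_l0_reverse.shape)  # torch.Size([64, 8])

x = torch.randn(seq_len, batch_size, input_dim)
lstm_out, (h_n, c_n) = lstm(x)

# Forward and backward outputs are concatenated on the last axis
print(lstm_out.shape)  # torch.Size([5, 3, 32]) == (seq_len, batch, 2 * hidden_dim)
# Final hidden states are stacked per direction instead
print(h_n.shape)       # torch.Size([2, 3, 16]) == (num_directions, batch, hidden_dim)
```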
In the `forward` method, we pass the input `x` through the LSTM layer. The LSTM output at the last time step is then passed through the fully connected layer, and the result is returned.
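One caveat about using the last time step: the backward half of that output was produced after reading only the final token, while the backward direction's summary of the whole sequence sits at time step 0. The sketch below checks this against the final hidden state `h_n` (for a single-layer bidirectional LSTM, `h_n[0]` is the forward direction and `h_n[1]` the backward direction) and shows a common alternative that concatenates the two final hidden states instead. The sizes are arbitrary, and this pooling is an alternative design choice, not what the original answer uses.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim = 16
lstm = nn.LSTM(input_size=8, hidden_size=hidden_dim, bidirectional=True)
x = torch.randn(5, 3, 8)  # (seq_len, batch, input_dim)

lstm_out, (h_n, _) = lstm(x)

# Forward half of the LAST time step equals the forward final hidden state
fwd_last = lstm_out[-1, :, :hidden_dim]
# Backward half of the FIRST time step equals the backward final hidden state
bwd_first = lstm_out[0, :, hidden_dim:]
print(torch.allclose(fwd_last, h_n[0]))   # True
print(torch.allclose(bwd_first, h_n[1]))  # True

# Alternative sequence summary (single layer): concatenate both final states
summary = torch.cat([h_n[0], h_n[1]], dim=1)  # (batch, 2 * hidden_dim)
print(summary.shape)  # torch.Size([3, 32])
```

Either tensor has `2 * hidden_dim` features per sequence, so it can feed the same `nn.Linear(hidden_dim * 2, output_dim)` layer.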
Note that in this example, we assume the input `x` has a shape of `(sequence_length, batch_size, input_dim)`, which matches the `nn.LSTM` default (`batch_first=False`). The output then has a shape of `(batch_size, output_dim)`.
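If your data is laid out batch-first, i.e. `(batch_size, sequence_length, input_dim)`, a variant of the model can pass `batch_first=True` to `nn.LSTM`; the time axis then moves to dimension 1, and indexing `lstm_out[:, -1, :]` selects the last time step. The class name `BiLSTMBatchFirst` and the sizes in the usage lines are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class BiLSTMBatchFirst(nn.Module):
    """Hypothetical batch-first variant of the BiLSTM model above."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # batch_first=True: input and output are (batch, seq_len, features)
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # *2 for both directions

    def forward(self, x):
        lstm_out, _ = self.lstm(x)          # (batch, seq_len, 2 * hidden_dim)
        return self.fc(lstm_out[:, -1, :])  # last time step -> (batch, output_dim)


# Usage with illustrative sizes: 4 sequences, 10 steps, 8 features each
model = BiLSTMBatchFirst(input_dim=8, hidden_dim=16, output_dim=3)
x = torch.randn(4, 10, 8)
logits = model(x)
print(logits.shape)  # torch.Size([4, 3])
```

The only differences from the sequence-first version are the `batch_first` flag and which axis is indexed for the last time step; the layer sizes are unchanged.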
I hope this helps! Let me know if you have any more questions.