如何用nn.Embedding和Bilstm提取文本信息,请用pytorch写一下代码?
时间: 2023-11-28 10:49:59 浏览: 117
以下是用 PyTorch 实现使用 nn.Embedding 和 BiLSTM 提取文本信息的代码示例:
```python
import torch
import torch.nn as nn
class BiLSTM(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim):
super(BiLSTM, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.bilstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
def forward(self, x):
x = self.embedding(x) # (batch_size, seq_len, embedding_dim)
h, _ = self.bilstm(x) # (batch_size, seq_len, 2*hidden_dim)
return h
# 示例
vocab_size = 10000
embedding_dim = 128
hidden_dim = 64
batch_size = 32
seq_len = 50
input_data = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) # (batch_size, seq_len)
model = BiLSTM(vocab_size, embedding_dim, hidden_dim)
output_data = model(input_data) # (batch_size, seq_len, 2*hidden_dim)
```
其中,`nn.Embedding` 用于将每个 token 转换为其对应的向量表示,`nn.LSTM` 则用于提取序列信息,同时由于 BiLSTM 的使用,输出的向量维度为 `2*hidden_dim`。最终输出的结果为 `(batch_size, seq_len, 2*hidden_dim)`。
阅读全文