BiLSTM-Attention
Implementing and Applying a BiLSTM with an Attention Mechanism
Model Architecture Design
In natural language processing, a BiLSTM combined with an attention mechanism can effectively capture both the contextual information in a sequence and its most important parts. The bidirectional LSTM layer encodes the input in both the forward and backward directions, while the attention mechanism lets the model focus on the most relevant words or spans [^1].
For a concrete implementation, a standard bidirectional LSTM unit is usually defined first as the basic building block:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.5):
        super(BiLSTM, self).__init__()
        self.bilstm = nn.LSTM(input_size=input_size,
                              hidden_size=hidden_size,
                              num_layers=num_layers,
                              batch_first=True,
                              bidirectional=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths=None):
        # x: (batch, seq_len, input_size); lengths: optional true lengths before padding
        if lengths is not None:
            # Pack the padded batch so the LSTM skips padding positions
            x = pack_padded_sequence(x, lengths.cpu(), batch_first=True,
                                     enforce_sorted=False)
        out, _ = self.bilstm(x)
        if lengths is not None:
            out, _ = pad_packed_sequence(out, batch_first=True)
        return self.dropout(out)   # (batch, seq_len, 2 * hidden_size)
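As a quick sanity check (a minimal sketch with illustrative sizes, not values from the article), the encoder maps a padded batch to hidden states of size 2 * hidden_size:

import torch

encoder = BiLSTM(input_size=100, hidden_size=64)
x = torch.randn(8, 20, 100)        # batch of 8 sequences, length 20, 100-dim inputs
out = encoder(x)
print(out.shape)                   # torch.Size([8, 20, 128]) -- both directions concatenated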
Next, an attention module is introduced to strengthen the feature representation. Attention scores are normalised into a weight distribution and used to form a weighted sum of the hidden states, and a masked_softmax function ensures that padding positions have no influence on that distribution [^4]:
import torch.nn.functional as F


def masked_softmax(scores, mask=None):
    # scores: (batch, seq_len, 1); mask: (batch, seq_len) boolean, True at real tokens
    if mask is not None:
        scores = scores.masked_fill(~mask.unsqueeze(-1), float('-inf'))
    # Normalise over the sequence dimension so padded positions receive zero weight
    alpha = F.softmax(scores, dim=1)
    return alpha
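The effect of the mask can be verified on a toy batch (the sizes below are illustrative assumptions): positions marked False receive exactly zero weight, while the remaining weights still sum to one:

import torch

scores = torch.randn(2, 4, 1)                     # raw scores for 2 sequences of length 4
mask = torch.tensor([[True, True, True, True],
                     [True, True, True, False]])  # last position of the second sequence is padding
alpha = masked_softmax(scores, mask)
print(alpha[1, 3])                                # tensor([0.]) -- padded position is ignored
print(alpha.sum(dim=1))                           # each sequence's weights sum to 1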
import torch


class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attn_w = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.tanh = nn.Tanh()
        self.u_w = nn.Parameter(torch.randn(hidden_dim * 2))

    def forward(self, hiddens, masks=None):
        # hiddens: (batch, seq_len, 2 * hidden_dim) BiLSTM outputs
        u = self.tanh(self.attn_w(hiddens))                         # (batch, seq_len, 2 * hidden_dim)
        score = torch.matmul(u, self.u_w).unsqueeze(dim=2)          # (batch, seq_len, 1)
        attention_weights = masked_softmax(score, masks)            # (batch, seq_len, 1)
        context_vector = (attention_weights * hiddens).sum(dim=1)   # (batch, 2 * hidden_dim)
        return context_vector, attention_weights.squeeze(-1)
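Used on top of the BiLSTM outputs, the layer collapses each sequence into a single context vector (again with illustrative sizes; hidden_dim=64 matches the earlier example):

import torch

attn = AttentionLayer(hidden_dim=64)            # expects inputs of size 2 * 64 = 128
hiddens = torch.randn(8, 20, 128)               # e.g. the BiLSTM outputs from above
mask = torch.ones(8, 20, dtype=torch.bool)      # no padding in this toy batch
context, weights = attn(hiddens, mask)
print(context.shape)                            # torch.Size([8, 128]) -- one summary vector per sequence
print(weights.shape)                            # torch.Size([8, 20])  -- one weight per time step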
Finally, the two parts are combined into the complete BiLSTM-Attention model; the sequence lengths are used both to pack the padded batch for the LSTM and to build the attention mask:
class BiLSTM_Attention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.encoder = BiLSTM(embedding_dim, hidden_dim)
        self.attention_layer = AttentionLayer(hidden_dim)
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, text, text_lengths):
        # text: (batch, seq_len) token ids; text_lengths: true lengths before padding
        embedded = self.embedding(text)                          # (batch, seq_len, embedding_dim)
        encoded_output = self.encoder(embedded, text_lengths)    # (batch, seq_len, 2 * hidden_dim)
        # Boolean mask: True at real tokens, False at padding positions
        mask = (torch.arange(encoded_output.size(1), device=text.device).unsqueeze(0)
                < text_lengths.to(text.device).unsqueeze(1))
        attn_output, attention_weights = self.attention_layer(encoded_output, mask)
        logits = self.fc_out(attn_output)                        # (batch, output_dim)
        return logits, attention_weights
This code implements a bidirectional long short-term memory network (BiLSTM) with an attention mechanism, suitable for deep-learning modelling in a range of NLP tasks such as text classification and sentiment analysis [^3].
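For example, a minimal text-classification sketch might look as follows (the vocabulary size, dimensions and dummy batch are illustrative assumptions, not values from the article):

import torch
import torch.nn.functional as F

model = BiLSTM_Attention(vocab_size=5000, embedding_dim=100,
                         hidden_dim=64, output_dim=2, pad_idx=0)

text = torch.randint(1, 5000, (4, 12))           # batch of 4 token-id sequences, max length 12
text[2, 9:] = 0                                  # pad the third sequence with pad_idx = 0
text_lengths = torch.tensor([12, 12, 9, 12])
labels = torch.tensor([0, 1, 1, 0])

logits, attention_weights = model(text, text_lengths)
loss = F.cross_entropy(logits, labels)
print(logits.shape, attention_weights.shape)     # torch.Size([4, 2]) torch.Size([4, 12])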