Code Implementation of a Global Attention Mechanism
The global attention mechanism is commonly used in sequence-to-sequence models: at each decoding step it attends over every element of the input sequence and forms a weighted sum of the encoder outputs. Below is a simple code implementation example.
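Concretely, given a decoder state $h_t$ and encoder outputs $\bar{h}_1, \dots, \bar{h}_S$, the mechanism normalizes a score for each source position into attention weights and forms the context vector as their weighted sum:

$$
\alpha_{t,s} = \frac{\exp\big(\mathrm{score}(h_t, \bar{h}_s)\big)}{\sum_{s'=1}^{S} \exp\big(\mathrm{score}(h_t, \bar{h}_{s'})\big)},
\qquad
c_t = \sum_{s=1}^{S} \alpha_{t,s}\, \bar{h}_s
$$

In the implementation below, the score takes the additive form $\mathrm{score}(h_t, \bar{h}_s) = \mathbf{1}^\top \tanh\big(W(h_t + \bar{h}_s)\big)$.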
```python
import torch
import torch.nn as nn


class GlobalAttention(nn.Module):
    def __init__(self, hidden_size):
        super(GlobalAttention, self).__init__()
        self.hidden_size = hidden_size
        self.linear_in = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_out = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden: [batch_size, hidden_size]
        # encoder_outputs: [seq_len, batch_size, hidden_size]
        seq_len = encoder_outputs.size(0)
        # Move encoder outputs to batch-first and broadcast the decoder
        # state across the time dimension so the two tensors line up
        encoder_outputs = encoder_outputs.transpose(0, 1)                      # [batch_size, seq_len, hidden_size]
        expanded_hidden = decoder_hidden.unsqueeze(1).expand(-1, seq_len, -1)  # [batch_size, seq_len, hidden_size]
        # Compute additive attention scores
        energy = torch.tanh(self.linear_in(expanded_hidden + encoder_outputs))  # [batch_size, seq_len, hidden_size]
        attention_scores = energy.sum(dim=2)                                    # [batch_size, seq_len]
        # Compute attention weights, normalized over source positions
        attention_weights = torch.softmax(attention_scores, dim=1)              # [batch_size, seq_len]
        # Compute context vector as the weighted sum of encoder outputs
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_size]
        context_vector = context_vector.squeeze(1)                                   # [batch_size, hidden_size]
        # Concatenate context vector and decoder hidden state
        output = torch.tanh(self.linear_out(torch.cat((context_vector, decoder_hidden), dim=1)))  # [batch_size, hidden_size]
        return output, attention_weights
```
The code above is a simple implementation of a global attention mechanism, consisting of an initialization method and a forward method. Given the decoder's hidden state and the encoder's outputs, it computes attention scores, attention weights, and a context vector, and returns the final output together with the weights.
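As a quick sanity check, the module can be exercised with random tensors. The sizes below (batch size 4, sequence length 10, hidden size 256) are arbitrary values chosen only for illustration:

```python
import torch

# Arbitrary sizes for illustration only
batch_size, seq_len, hidden_size = 4, 10, 256

attention = GlobalAttention(hidden_size)  # the class defined above
decoder_hidden = torch.randn(batch_size, hidden_size)
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)

output, attention_weights = attention(decoder_hidden, encoder_outputs)
print(output.shape)                  # torch.Size([4, 256])
print(attention_weights.shape)       # torch.Size([4, 10])
print(attention_weights.sum(dim=1))  # each row sums to ~1.0
```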
Note that this is only a simple example; in practice it may need to be modified and tuned for the specific task and model architecture.
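One common adjustment, for example, is to replace the additive scoring above with the multiplicative ("general") scoring from Luong et al. (2015), where global attention was originally described. The sketch below is a hypothetical variant along those lines, not part of the original example:

```python
import torch
import torch.nn as nn


class DotProductGlobalAttention(nn.Module):
    """A hypothetical variant using 'general' dot-product scoring
    instead of the additive scoring in the example above."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear_in = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_out = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden: [batch_size, hidden_size]
        # encoder_outputs: [seq_len, batch_size, hidden_size]
        encoder_outputs = encoder_outputs.transpose(0, 1)  # [batch_size, seq_len, hidden_size]
        # score(h_t, h_s) = h_t^T W h_s  ("general" scoring)
        query = self.linear_in(decoder_hidden).unsqueeze(2)               # [batch_size, hidden_size, 1]
        attention_scores = torch.bmm(encoder_outputs, query).squeeze(2)   # [batch_size, seq_len]
        attention_weights = torch.softmax(attention_scores, dim=1)        # [batch_size, seq_len]
        # Weighted sum of encoder outputs -> context vector
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [batch_size, hidden_size]
        output = torch.tanh(self.linear_out(torch.cat((context_vector, decoder_hidden), dim=1)))
        return output, attention_weights
```

The dot-product form avoids materializing the expanded `[batch_size, seq_len, hidden_size]` sum, which can matter for long sequences.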