注意力机制有几种?分别给出详细介绍以及各自代表性的简单Python代码
时间: 2024-09-26 13:04:31 浏览: 4
注意力机制主要有几种常见的变体,包括基于点的注意力(Point-wise Attention)、基于序列的自注意力(Self-Attention,如Transformer)、以及混合注意力(Muti-head Attention)等。
1. **基于点的注意力(Point-wise Attention)**:这是最简单的形式,通常用于循环神经网络(RNNs)中。例如,在LSTM或GRU中,每个时间步会单独计算其前向和后向隐藏状态的注意力。Python代码示例(简化版):
```python
class PointWiseAttention(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.attn = nn.Linear(hidden_size, 1)
def forward(self, query, keys, values):
energy = self.attn(torch.cat((query, keys), dim=2)).squeeze(-1)
weights = F.softmax(energy, dim=1)
context = torch.bmm(weights.unsqueeze(1), values).squeeze(1)
return context
```
2. **基于序列的自注意力(Self-Attention)**:比如Transformer中的自注意力模块,它同时考虑了输入序列的所有位置。每个查询向量都会与其他所有键值对对应并计算注意力分数,再进行加权求和得到上下文向量。Python代码示例(简化版):
```python
from transformers import MultiHeadSelfAttention
class SelfAttentionLayer(MultiHeadSelfAttention):
# 使用Transformer库中的MultiHeadSelfAttention组件
def forward(self, input_tensor, attention_mask=None):
return super().forward(input_tensor, attn_mask=attention_mask)
```
3. **混合注意力(Muti-head Attention)**:为了提高性能和表达能力,Transformer引入了多头注意力,即同时执行多个独立的注意力子层,每个子层专注于输入的不同方面。Python代码示例(简化版):
```python
class MultiHeadAttentionLayer(MultiheadAttention):
def forward(self, queries, keys, values, mask=None):
return super().forward(queries, keys, values, attn_mask=mask)
```
每种注意力机制都有其特定的应用场景和优缺点,选择哪种取决于具体的任务需求。