Multi-Head Attention Mechanism: PyTorch Implementation
### Implementing Multi-Head Attention
In neural machine translation and other sequence modeling tasks, the multi-head attention mechanism lets the model attend to different parts of the input simultaneously, which improves performance [^1]. The following shows one way to implement multi-head attention with the PyTorch framework.
```python
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Per-head dimension and the linear projections W_Q, W_K, W_V and W_O
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Attention scores: Q K^T scaled by sqrt(d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
        if mask is not None:
            # Masked positions (mask == 0) get a large negative score so
            # they receive ~zero weight after the softmax
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * d_k)

    def forward(self, query, key, value, mask=None):
        # Project the inputs and split them into heads
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        # Attention per head, then merge the heads and apply the output projection
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask=mask)
        combined_attn_output = self.combine_heads(attn_output)
        final_output = self.W_o(combined_attn_output)
        return final_output
```
This snippet defines the `MultiHeadAttention` class, which implements the core logic of multi-head attention. By projecting the query, key, and value vectors, splitting them into multiple heads, and computing attention scores within each head separately, the model can capture different relationship patterns between positions.
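For reference, each head's `scaled_dot_product_attention` step computes the standard formula

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

over its slice of the model dimension. Below is a minimal usage sketch of the class above; the hyperparameters and tensor shapes (`d_model=512`, `num_heads=8`, batch size 2, sequence length 10) are illustrative assumptions, not part of the original snippet.

```python
# Illustrative hyperparameters (assumed values, not from the original post)
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

batch_size, seq_len = 2, 10
x = torch.rand(batch_size, seq_len, d_model)  # dummy input embeddings

# Self-attention: query, key, and value all come from the same sequence
out = mha(x, x, x)
print(out.shape)  # torch.Size([2, 10, 512])

# Optional causal mask: each position may only attend to itself and earlier positions.
# Shape (1, 1, seq_len, seq_len) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
masked_out = mha(x, x, x, mask=causal_mask)
print(masked_out.shape)  # torch.Size([2, 10, 512])
```

For cross-attention (e.g., a decoder attending to encoder outputs), the key and value arguments would simply come from a different sequence than the query.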