# Reference code: https://zhuanlan.zhihu.com/p/611684065
# Single-head self-attention
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    The scores are ``query @ key^T / sqrt(d_k)``, where ``d_k`` is the size
    of the last dimension of ``query``. A softmax over the last dimension
    turns the scores into weights, which are then applied to ``value``.

    Args:
        query: Query tensor; the size of its last dimension is ``d_k``.
        key: Key tensor; its last two dimensions are transposed before it
            is multiplied with ``query``.
        value: Value tensor, multiplied by the attention weights.
        mask: Optional tensor broadcastable to the score tensor. Positions
            where ``mask == 0`` are excluded from the softmax. A row that is
            masked everywhere ends up with uniform weights.
        dropout: Optional callable (for example an ``nn.Dropout`` module)
            applied to the attention weights.

    Returns:
        A tuple ``(output, p_attn)``: the weighted sum of ``value`` and the
        attention weights (after ``dropout``, when one is given).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the score dtype instead of a
        # hard-coded -1e9: that constant is not representable in float16 and
        # makes masked_fill raise for half-precision inputs.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
# Multi-head self-attention
class PrepareForMultiHeadAttention(nn.Module):
    """Project the input linearly and split the result into attention heads.

    A single ``nn.Linear`` maps the last dimension from ``d_model`` to
    ``heads * d_k``; the result is then viewed so that the last dimension is
    replaced by two dimensions, ``heads`` and ``d_k``.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """Create the projection layer.

        Args:
            d_model: Size of the last dimension of the input.
            heads: Number of attention heads.
            d_k: Number of features per head.
            bias: Whether the linear projection has a bias term.
        """
        super().__init__()
        # One projection produces the features of all heads at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        """Project ``x`` and split its last dimension into heads.

        Args:
            x: Tensor whose last dimension has size ``d_model``, for example
                ``[seq_len, batch_size, d_model]`` or
                ``[batch_size, d_model]``.

        Returns:
            Tensor with the leading dimensions of ``x`` followed by
            ``[heads, d_k]``, for example
            ``[seq_len, batch_size, heads, d_k]`` or
            ``[batch_size, heads, d_k]``.
        """
        # Keep every leading dimension; only the last one is replaced.
        head_shape = x.shape[:-1]
        x = self.linear(x)
        # Split the last dimension (heads * d_k) into (heads, d_k).
        x = x.view(*head_shape, self.heads, self.d_k)
        return x
class MultiHeadAttention(nn.Module):