多头自注意力机制的计算公式
时间: 2023-11-13 09:17:21 浏览: 81
多头自注意力机制的计算公式如下:
假设输入序列为 $X = [x_1, x_2, \dots, x_n]$,其中 $x_i \in \mathbb{R}^d$,多头注意力头数为 $h$,每个头的注意力权重为 $w_i^h$,则多头自注意力机制的输出为:
$$
\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h) W^O
$$
其中,$W^O \in \mathbb{R}^{hd \times d}$ 是输出矩阵,$\text{Concat}$ 表示将多个头的输出拼接在一起。每个头的输出 $\text{head}_h$ 的计算公式为:
$$
\text{head}_h = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)
$$
其中,$W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d \times d_k}$ 分别是查询、键、值的权重矩阵,$d_k = \frac{d}{h}$ 是每个头的维度。
注意力权重 $w_i^h$ 的计算公式为:
$$
w_i^h = \frac{\exp(e_i^h)}{\sum_{j=1}^n \exp(e_j^h)}
$$
其中,$e_i^h$ 表示第 $h$ 个头中第 $i$ 个位置与其他位置的相似度,计算公式为:
$$
e_i^h = \frac{(XW_i^Q)_i (XW_i^K)_i^T}{\sqrt{d_k}}
$$
注意力的输出 $\text{Attention}(Q,K,V)$ 的计算公式为:
$$
\text{Attention}(Q,K,V) = \text{Softmax}(QK^T/\sqrt{d_k})V
$$
其中,$Q, K, V$ 分别是查询、键、值,$\text{Softmax}$ 表示对每一行进行 softmax 操作。
阅读全文