(query @ key.transpose(-2, -1)) * (query.shape[-1]) ** -0.5
时间: 2023-10-21 14:06:13 浏览: 44
考虑一个三变量函数(见下方代码),寻找这个函数的在(0.5, 0.5, 0.5)和(-0.5, -0.5, -0.5)附近的两个
这是一个使用自注意力机制计算查询-键相似度的公式。其中,query 表示查询向量,key 表示键向量,transpose(-2, -1) 表示将 key 的倒数第一维和倒数第二维交换位置,使得它们可以进行矩阵乘法。乘上 (query.shape[-1]) ** -0.5 是为了缩放相似度,以避免它们过大或过小。最终的结果是一个形状为 (batch_size, num_heads, query_len, key_len) 的相似度矩阵,它将在后续计算中被用作注意力权重。
阅读全文