attn_weights = Dot(axes=[3, 3])([input1, input1]) 与 attn_weights = Dot(axes=[2, 2])([input1, input1])的区别
时间: 2024-06-07 16:08:58 浏览: 125
flash-attn wheel
这两个操作的区别在于`axes`参数的不同。
`Dot(axes=[3, 3])([input1, input1])`表示在`input1`的最后两个维度上进行点积运算,即将`input1`的每个样本的最后一个维度与自身的最后一个维度进行点积,得到一个二维的注意力矩阵。这里的`input1`的形状应该是`(batch_size, seq_len, embedding_dim)`。
而`Dot(axes=[2, 2])([input1, input1])`表示在`input1`的第二个维度上进行点积运算,即将`input1`的每个样本的所有维度中的第二个维度(即词向量维度)分别与自身所有样本的第二个维度进行点积,得到一个二维的注意力矩阵。这里的`input1`的形状应该是`(batch_size, seq_len, embedding_dim)`。
因此,这两个操作得到的注意力矩阵的形状不同,具体实现取决于模型的需求。
阅读全文