transformer 自注意力图
时间: 2025-01-08 15:10:35 浏览: 11
Transformer 自注意力机制工作原理
在Transformer架构中,自注意力机制允许模型不同位置的输入相互关联,从而捕捉序列中的依赖关系。具体来说,在计算过程中,每个位置都会生成三个向量:查询(Query)、键(Key)以及值(Value)。这些向量用于衡量各个部分之间的相似度并加权求和得到最终输出。
对于给定的位置i, 查询Qi会与其他所有位置j产生的键Kj进行比较,以此来决定应赋予多少权重wj给对应的Vj。此过程可以形式化描述如下:
[ \text{Attention}(Q,K,V)=\text{softmax}\left(\frac{Q K^{T}}{\sqrt {d_k}}\right)V ]
其中( d_k )代表键维度大小[^1]。
这种设计使得即使距离较远的信息也能被有效利用起来,而不仅仅局限于局部上下文环境内。此外,多头注意力机制进一步增强了表达能力,它通过多个平行运行的关注子空间实现更丰富的特征提取。
可视化表示
为了更好地理解上述抽象概念,可以通过图形界面展示来自不同层的数据流动情况。例如,在处理一句话时,可以看到词语间如何建立联系,并观察到哪些词对特定预测产生了较大影响。这类工具通常提供交互功能让用户探索内部运作细节,包括但不限于查看各节点间的连接强度分布图谱等[^2]。
import torch
import matplotlib.pyplot as plt
def plot_attention_map(attention_weights):
"""
绘制注意力权重热力图
参数:
attention_weights (torch.Tensor): 注意力权重矩阵
"""
fig, ax = plt.subplots()
im = ax.imshow(attention_weights.detach().numpy(), cmap='viridis')
# 添加颜色条
cbar = ax.figure.colorbar(im, ax=ax)
ax.set_xticks(np.arange(len(attention_weights)))
ax.set_yticks(np.arange(len(attention_weights)))
# 显示数值标签
for i in range(len(attention_weights)):
for j in range(len(attention_weights)):
text = ax.text(j, i, f"{attention_weights[i][j]:.2f}",
ha="center", va="center", color="w")
ax.set_title("Self-Attention Weights Heatmap")
plt.show()
# 假设有一个形状为 [seq_len, seq_len] 的张量作为注意力建模的结果
fake_attention_weights = torch.rand((8, 8))
plot_attention_map(fake_attention_weights)