Cross-Attention Code Example
Cross-attention is a module commonly used in deep learning, and in the Transformer architecture in particular. Unlike self-attention, where the queries, keys, and values all come from the same sequence, cross-attention lets the positions of one sequence (the queries) attend to the positions of another (the keys and values), as in a decoder layer attending to encoder outputs. Below is a simple PyTorch example showing how to implement a cross-attention layer:
```python
import torch
from torch import nn


class CrossAttentionLayer(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Separate linear projections for queries, keys, and values
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        # Project and reshape to (batch, num_heads, seq_len, head_dim)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores, normalized with softmax
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = torch.softmax(attention_weights, dim=-1)
        context = torch.matmul(attention_weights, value)
        # Merge the heads back to (batch, seq_len, d_model) and project out
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out_proj(context)
        return output


# Example usage: the queries come from one sequence and the keys/values
# from another, which is what makes this cross- rather than self-attention.
query = torch.randn(16, 64, 512)        # (batch, target_len, d_model)
key = value = torch.randn(16, 32, 512)  # (batch, source_len, d_model)
attention_layer = CrossAttentionLayer(512, 8)
output = attention_layer(query, key, value)  # shape: (16, 64, 512)
```
In this example, `query` has shape `(batch_size, target_length, d_model)`, while `key` and `value` have shape `(batch_size, source_length, d_model)`; each position carries a `d_model`-dimensional representation. The `forward` method computes the attention weights, uses them to build context vectors from `value`, and finally maps the result back to the model dimension through `out_proj`.
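For reference, PyTorch also ships a built-in `nn.MultiheadAttention` module that performs the same scaled dot-product computation (with its own internal projection weights, so its outputs will not numerically match the hand-written layer above). A minimal sketch:

```python
import torch
from torch import nn

# batch_first=True makes the inputs (batch_size, sequence_length, d_model),
# matching the shapes used in the example above.
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

query = torch.randn(16, 64, 512)   # target sequence
memory = torch.randn(16, 32, 512)  # source sequence, used as keys and values

# Returns the attended output and the (head-averaged) attention weights.
output, attn_weights = mha(query, memory, memory)
print(output.shape)        # torch.Size([16, 64, 512])
print(attn_weights.shape)  # torch.Size([16, 64, 32])
```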
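In a full Transformer decoder block, the cross-attention layer is normally wrapped with a residual (skip) connection and layer normalization, which is presumably what the "cross skip connection" wording alludes to. Below is a minimal sketch assuming the `CrossAttentionLayer` defined above; the `CrossAttentionBlock` name and the pre-norm ordering are illustrative choices, not the only convention:

```python
class CrossAttentionBlock(nn.Module):
    """Cross-attention wrapped with a residual (skip) connection and LayerNorm."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.cross_attn = CrossAttentionLayer(d_model, num_heads)

    def forward(self, x, memory):
        # Pre-norm variant: normalize the queries, attend over the memory
        # sequence, then add the result back onto the input (the skip path).
        return x + self.cross_attn(self.norm(x), memory, memory)

block = CrossAttentionBlock(512, 8)
x = torch.randn(16, 64, 512)       # decoder-side sequence
memory = torch.randn(16, 32, 512)  # encoder-side sequence
out = block(x, memory)             # shape: (16, 64, 512)
```

The skip connection keeps the original representation flowing through unchanged, so the attention layer only needs to learn a residual update; this is the standard arrangement in Transformer encoder and decoder stacks.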