交叉注意力机制 代码
时间: 2023-08-13 10:11:48 浏览: 211
以下是一个简单的交叉注意力机制的代码示例:
```python
import torch
import torch.nn as nn
class CrossAttention(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(CrossAttention, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.query_fc = nn.Linear(input_dim, hidden_dim)
self.key_fc = nn.Linear(input_dim, hidden_dim)
self.value_fc = nn.Linear(input_dim, hidden_dim)
def forward(self, query, key, value):
q = self.query_fc(query)
k = self.key_fc(key)
v = self.value_fc(value)
scores = torch.matmul(q, k.transpose(-2, -1))
attention_weights = torch.softmax(scores, dim=-1)
weighted_values = torch.matmul(attention_weights, v)
return weighted_values
# 使用示例
input_dim = 100
hidden_dim = 50
cross_attn = CrossAttention(input_dim, hidden_dim)
query = torch.randn((32, 10, input_dim)) # (batch_size, query_length, input_dim)
key = torch.randn((32, 20, input_dim)) # (batch_size, key_length, input_dim)
value = torch.randn((32, 20, input_dim)) # (batch_size, key_length, input_dim)
output = cross_attn(query, key, value)
print(output.shape) # 输出: torch.Size([32, 10, 50])
```
这是一个简单的交叉注意力机制,其中`CrossAttention`类包含了查询(query)、键(key)和值(value)的线性转换。在`forward`方法中,通过矩阵乘法计算注意力分数,然后使用 softmax 函数得到注意力权重。最后,将权重与值相乘得到加权值。以上代码仅为示例,实际应用中可能需要根据具体需求进行修改。
阅读全文