Implementing cross attention between two features in Python
Below is example code that implements cross attention between two features in Python:
```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    def __init__(self, feature_dim):
        super(CrossAttention, self).__init__()
        self.feature_dim = feature_dim
        self.query_fc = nn.Linear(feature_dim, feature_dim, bias=False)
        self.key_fc = nn.Linear(feature_dim, feature_dim, bias=False)
        self.value_fc = nn.Linear(feature_dim, feature_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, feature1, feature2):
        """
        feature1: (batch_size, seq_len1, feature_dim)
        feature2: (batch_size, seq_len2, feature_dim)
        """
        # Compute query, key, and value tensors for feature1
        query1 = self.query_fc(feature1)  # (batch_size, seq_len1, feature_dim)
        key1 = self.key_fc(feature1)      # (batch_size, seq_len1, feature_dim)
        value1 = self.value_fc(feature1)  # (batch_size, seq_len1, feature_dim)
        # Compute query, key, and value tensors for feature2
        # (query2 and key1 are kept for symmetry but are not used in the
        #  single scoring step below)
        query2 = self.query_fc(feature2)  # (batch_size, seq_len2, feature_dim)
        key2 = self.key_fc(feature2)      # (batch_size, seq_len2, feature_dim)
        value2 = self.value_fc(feature2)  # (batch_size, seq_len2, feature_dim)
        # Compute attention scores between feature1 and feature2
        scores = torch.bmm(query1, key2.transpose(1, 2))  # (batch_size, seq_len1, seq_len2)
        # Normalize attention scores using softmax
        attn_weights = self.softmax(scores)  # (batch_size, seq_len1, seq_len2)
        # Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)
        # Compute the weighted sum of value2 using the attention weights
        attended_feature2 = torch.bmm(attn_weights, value2)  # (batch_size, seq_len1, feature_dim)
        # Compute the weighted sum of value1 using the transposed attention weights
        attended_feature1 = torch.bmm(attn_weights.transpose(1, 2), value1)  # (batch_size, seq_len2, feature_dim)
        # Concatenate the attended features with the original features
        feature1 = torch.cat([feature1, attended_feature2], dim=-1)  # (batch_size, seq_len1, 2*feature_dim)
        feature2 = torch.cat([feature2, attended_feature1], dim=-1)  # (batch_size, seq_len2, 2*feature_dim)
        return feature1, feature2
```
This code implements a PyTorch module named CrossAttention that takes two features as input and computes the cross attention between them. Specifically, it first uses three fully connected layers to project each time step of each feature into query, key, and value tensors. It then computes the attention scores between feature1 and feature2, normalizes them with a softmax, and regularizes them with dropout. Next, it uses the attention weights to compute a weighted sum of feature2's value tensor, and uses the transposed weights to compute a weighted sum of feature1's value tensor. Finally, it concatenates each attended feature with the corresponding original feature and returns both results.
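One detail worth noting: the module above uses raw dot products as attention scores. Standard scaled dot-product attention additionally divides the scores by sqrt(feature_dim) before the softmax to keep their magnitude stable. A minimal, hypothetical helper sketching that variant (not part of the original module; if used, it would replace the raw torch.bmm scoring line in forward):
```python
import math
import torch


def scaled_scores(query, key, feature_dim):
    # Hypothetical helper: scaled dot-product scores, i.e. Q K^T / sqrt(d),
    # computed batch-wise for (batch, seq_len, feature_dim) tensors.
    return torch.bmm(query, key.transpose(1, 2)) / math.sqrt(feature_dim)
```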
You can test the CrossAttention module with the following code example:
```python
# Define the input features
feature1 = torch.randn(32, 10, 64) # (batch_size, seq_len1, feature_dim)
feature2 = torch.randn(32, 8, 64) # (batch_size, seq_len2, feature_dim)
# Create the CrossAttention module
cross_attn = CrossAttention(feature_dim=64)
# Apply CrossAttention to the input features
new_feature1, new_feature2 = cross_attn(feature1, feature2)
# Print the shapes of the output features
print(new_feature1.shape) # (32, 10, 128)
print(new_feature2.shape) # (32, 8, 128)
```
In this example, we use randomly generated feature tensors as input and apply the CrossAttention module to compute the cross attention between them. Finally, we print the shapes of the output features to verify that they were computed correctly.
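For comparison, the same kind of cross attention can also be expressed with PyTorch's built-in nn.MultiheadAttention by passing one feature as the query and the other as the key and value. A minimal sketch, assuming a recent PyTorch with batch_first support; unlike the custom module above, it returns the attended values directly instead of concatenating them with the inputs:
```python
import torch
import torch.nn as nn

# Single-head cross attention over 64-dimensional features
mha = nn.MultiheadAttention(embed_dim=64, num_heads=1, batch_first=True)

feature1 = torch.randn(32, 10, 64)  # query side
feature2 = torch.randn(32, 8, 64)   # key/value side

# feature1 attends to feature2; the output has the same shape as feature1
attended, attn_weights = mha(query=feature1, key=feature2, value=feature2)
print(attended.shape)      # torch.Size([32, 10, 64])
print(attn_weights.shape)  # torch.Size([32, 10, 8])
```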