展示一下self-attention和cross-attention的代码片段
时间: 2023-08-27 10:05:09 浏览: 104
基于深度学习的声纹识别(self-attention)
下面是一个简单的self-attention代码片段,用于对输入序列进行自注意力计算(注意:该实现最后在序列维度上求和,输出的是每个样本的一个池化向量,而不是逐位置的序列表示):
```
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with sequence pooling.

    Projects the input into query/key/value spaces, attends each position
    to every other position, then sums over the sequence axis — so the
    output is one pooled vector per batch element, not a per-token sequence.
    NOTE(review): the final sum collapses the sequence dimension; confirm
    this pooling is intended by downstream callers.
    """

    def __init__(self, input_dim, hidden_dim):
        super(SelfAttention, self).__init__()
        # One projection per attention role, each mapping input_dim -> hidden_dim.
        self.query_linear = nn.Linear(input_dim, hidden_dim)
        self.key_linear = nn.Linear(input_dim, hidden_dim)
        self.value_linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        # Project the same input into the three attention roles.
        q = self.query_linear(x)
        k = self.key_linear(x)
        v = self.value_linear(x)
        # Scaled dot-product scores: (..., seq, seq).
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        # Each query position gets a probability distribution over key positions.
        attn = torch.softmax(scores, dim=-1)
        # Attend, then collapse the sequence axis into a single vector per batch.
        context = attn @ v
        return context.sum(dim=-2)
```
下面是一个简单的cross-attention代码片段,用于计算query序列和key-value序列之间的跨注意力(注意:该实现最后在query序列维度上求和,输出的是每个样本的一个池化向量,而不是逐query位置的序列表示):
```
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention with sequence pooling.

    Attends a query sequence over a separate key/value sequence, then sums
    over the query positions — so the output is one pooled vector per batch
    element, not a per-query-token sequence.
    NOTE(review): the final sum collapses the query dimension; confirm this
    pooling is intended by downstream callers.
    """

    def __init__(self, query_dim, key_dim, value_dim, hidden_dim):
        super(CrossAttention, self).__init__()
        # Separate projections allow query/key/value inputs of different widths.
        self.query_linear = nn.Linear(query_dim, hidden_dim)
        self.key_linear = nn.Linear(key_dim, hidden_dim)
        self.value_linear = nn.Linear(value_dim, hidden_dim)

    def forward(self, query, key, value):
        # Project all three inputs into the shared hidden space.
        q = self.query_linear(query)
        k = self.key_linear(key)
        v = self.value_linear(value)
        # Scaled dot-product scores: (..., q_len, kv_len).
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        # Each query position gets a probability distribution over key positions.
        attn = torch.softmax(scores, dim=-1)
        # Attend, then collapse the query axis into a single vector per batch.
        context = attn @ v
        return context.sum(dim=-2)
```
阅读全文