Complex Soft Attention Mechanism: Extended Implementation Code
Below is a more elaborate PyTorch implementation of the soft attention mechanism, given as an extended version of the code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexSoftAttention(nn.Module):
    def __init__(self, in_dim, hidden_dim=None):
        super(ComplexSoftAttention, self).__init__()
        if hidden_dim is None:
            hidden_dim = in_dim
        # The hidden dimension is split into a real half and an imaginary half.
        assert hidden_dim % 2 == 0, "hidden_dim must be even"
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.query_layer = nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=False)
        self.key_layer = nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=False)
        self.value_layer = nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=False)

    def forward(self, query, key, value, mask=None):
        """
        :param query: [batch_size, query_len, in_dim]
        :param key:   [batch_size, key_len, in_dim]
        :param value: [batch_size, key_len, in_dim]
        :param mask:  [batch_size, key_len], 1 for valid positions, 0 for padding
        :return:      [batch_size, query_len, hidden_dim]
        """
        query = self.query_layer(query)  # [batch_size, query_len, hidden_dim]
        key = self.key_layer(key)        # [batch_size, key_len, hidden_dim]
        value = self.value_layer(value)  # [batch_size, key_len, hidden_dim]

        # Split each projection into a real half and an imaginary half
        half = self.hidden_dim // 2
        real_query, imag_query = query[:, :, :half], query[:, :, half:]
        real_key, imag_key = key[:, :, :half], key[:, :, half:]
        real_value, imag_value = value[:, :, :half], value[:, :, half:]

        # Complex inner product: (a + bi)(c + di) has real part ac - bd and imaginary part ad + bc
        real_attention_weights = (torch.einsum('bqd,bkd->bqk', real_query, real_key)
                                  - torch.einsum('bqd,bkd->bqk', imag_query, imag_key))
        imag_attention_weights = (torch.einsum('bqd,bkd->bqk', real_query, imag_key)
                                  + torch.einsum('bqd,bkd->bqk', imag_query, real_key))
        # Stack both parts: [batch_size, query_len, key_len, 2]
        attention_weights = torch.stack([real_attention_weights, imag_attention_weights], dim=-1)

        # Mask out padded key positions
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(-1)  # [batch_size, 1, key_len, 1], broadcasts over queries and parts
            attention_weights = attention_weights.masked_fill(mask == 0, -1e9)

        # Softmax over the key dimension, applied to the real and imaginary parts separately
        attention_weights = F.softmax(attention_weights, dim=2)  # [batch_size, query_len, key_len, 2]
        real_attention_weights = attention_weights[:, :, :, 0]
        imag_attention_weights = attention_weights[:, :, :, 1]

        # Recombine into complex attention weights and apply square-root scaling
        attention_weights = torch.complex(real_attention_weights, imag_attention_weights)  # [batch_size, query_len, key_len]
        attention_weights = attention_weights / torch.sqrt(
            torch.tensor(self.hidden_dim, dtype=torch.float32))

        # Weighted sum over complex-valued values
        complex_value = torch.complex(real_value, imag_value)  # [batch_size, key_len, hidden_dim // 2]
        output = torch.einsum('bqk,bkd->bqd', attention_weights, complex_value)
        # Re-assemble the real and imaginary parts into a real tensor of width hidden_dim
        output = torch.cat([output.real, output.imag], dim=-1)  # [batch_size, query_len, hidden_dim]
        return output
```
The core of this model is again the complex-valued soft attention mechanism, but compared with the previous implementation it handles the complex arithmetic and the masking more rigorously. The input projections are split into real and imaginary halves, and PyTorch's einsum is used to compute the complex inner product between queries and keys. To obtain the attention scores, the real and imaginary parts are each passed through a softmax, recombined into a single complex-valued weight tensor, and scaled by the square root of the hidden dimension. Finally, einsum is used again to compute the weighted sum over the values and reassemble the result into a single output tensor.
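The split-einsum formulation rests on the identity that (a + bi)(c + di) has real part ac - bd and imaginary part ad + bc. The standalone snippet below is a quick sanity sketch, not part of the module (all tensor sizes are arbitrary illustration values), showing that the split formulation agrees with PyTorch's native complex einsum:

```python
import torch

b, q, k, d = 2, 3, 4, 8  # toy sizes for illustration only
real_q, imag_q = torch.randn(b, q, d), torch.randn(b, q, d)
real_k, imag_k = torch.randn(b, k, d), torch.randn(b, k, d)

# Split formulation, as used in the forward pass above
real_part = (torch.einsum('bqd,bkd->bqk', real_q, real_k)
             - torch.einsum('bqd,bkd->bqk', imag_q, imag_k))
imag_part = (torch.einsum('bqd,bkd->bqk', real_q, imag_k)
             + torch.einsum('bqd,bkd->bqk', imag_q, real_k))

# Native complex formulation: element-wise product summed over d, no conjugation
q_c = torch.complex(real_q, imag_q)
k_c = torch.complex(real_k, imag_k)
native = torch.einsum('bqd,bkd->bqk', q_c, k_c)

print(torch.allclose(native.real, real_part, atol=1e-5))  # True
print(torch.allclose(native.imag, imag_part, atol=1e-5))  # True
```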
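For completeness, here is a minimal usage sketch, assuming the `ComplexSoftAttention` class defined above is in scope; the batch size, sequence lengths, and dimensions are made-up illustration values:

```python
import torch

# Toy dimensions for illustration only
batch_size, query_len, key_len, in_dim = 2, 5, 7, 32

query = torch.randn(batch_size, query_len, in_dim)
key = torch.randn(batch_size, key_len, in_dim)
value = torch.randn(batch_size, key_len, in_dim)

# Mask out the last two key positions of every sequence (1 = keep, 0 = mask)
mask = torch.ones(batch_size, key_len)
mask[:, -2:] = 0

attention = ComplexSoftAttention(in_dim=in_dim, hidden_dim=64)
output = attention(query, key, value, mask=mask)
print(output.shape)  # expected: torch.Size([2, 5, 64])
```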