自注意力机制的代码实现
时间: 2023-11-20 19:50:37 浏览: 54
自注意力机制(Self-Attention Mechanism)是一种用于处理序列数据的机制,它可以在不同位置之间建立关联,从而更好地捕捉序列中的长程依赖关系。下面是自注意力机制的代码实现:
```python
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, hidden_size, num_heads):
super(SelfAttention, self).__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_size = hidden_size // num_heads
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
self.dropout = nn.Dropout(0.1)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, seq_len, hidden_size = x.size()
# 将输入张量 x 沿着最后一个维度分割成 num_heads 个张量
# 每个张量的形状为 [batch_size, seq_len, head_size]
query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1,2)
# 计算注意力分数
scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32))
attn_weights = self.softmax(scores)
attn_weights = self.dropout(attn_weights)
# 计算加权和
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
return attn_output
```
上述代码实现了一个 SelfAttention 类,它包含了一个线性变换层(query、key、value)、一个 Dropout 层、一个 Softmax 层以及一个矩阵乘法操作。在 forward 方法中,我们首先将输入张量 x 沿着最后一个维度分割成 num_heads 个张量,然后分别进行线性变换,得到 query、key、value 张量。接着,我们计算注意力分数,通过 Softmax 层得到注意力权重,再通过 Dropout 层进行随机失活。最后,我们将注意力权重与 value 张量相乘并求和,得到最终的注意力输出。