写一个多头注意力机制的代码
时间: 2024-02-15 21:59:36 浏览: 75
在下面的代码中,我将展示一个简单的多头注意力机制的实现。假设我们有一个输入张量`input_tensor`,其形状为(batch_size, seq_length, hidden_size),其中`hidden_size`表示隐藏层的维度。我们将使用`num_heads`个注意力头来处理输入张量。
```python
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, hidden_size, num_heads):
super(MultiHeadAttention, self).__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.query_linear = nn.Linear(hidden_size, hidden_size)
self.key_linear = nn.Linear(hidden_size, hidden_size)
self.value_linear = nn.Linear(hidden_size, hidden_size)
self.output_linear = nn.Linear(hidden_size, hidden_size)
def forward(self, input_tensor):
batch_size, seq_length, _ = input_tensor.size()
# 将输入张量通过线性层映射得到查询、键和值
query = self.query_linear(input_tensor)
key = self.key_linear(input_tensor)
value = self.value_linear(input_tensor)
# 将查询、键和值张量分割成多个头
query = query.view(batch_size, seq_length, self.num_heads, self.hidden_size // self.num_heads)
key = key.view(batch_size, seq_length, self.num_heads, self.hidden_size // self.num_heads)
value = value.view(batch_size, seq_length, self.num_heads, self.hidden_size // self.num_heads)
# 计算注意力得分
scores = torch.matmul(query, key.transpose(-2, -1)) / (self.hidden_size // self.num_heads)**0.5
# 对注意力得分进行softmax归一化
attention_weights = torch.softmax(scores, dim=-1)
# 使用注意力权重对值进行加权求和
weighted_values = torch.matmul(attention_weights, value)
# 将多个头的输出拼接在一起
weighted_values = weighted_values.view(batch_size, seq_length, self.hidden_size)
output = self.output_linear(weighted_values)
return output
```
这段代码定义了一个名为`MultiHeadAttention`的PyTorch模块,它接受一个隐藏层维度`hidden_size`和一个注意力头数`num_heads`作为参数。在`forward`方法中,我们首先通过线性层将输入张量映射为查询、键和值张量。然后,我们将这些张量分割成多个头,并计算注意力得分。接下来,我们对注意力得分进行softmax归一化,并使用注意力权重对值进行加权求和。最后,我们将多个头的输出拼接在一起,并通过线性层映射为最终的输出。
阅读全文