Implementing Multi-Head Self-Attention in PyTorch
Sure! Below is example code that implements Multi-Head Self-Attention in PyTorch:
```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections to obtain Q, K, V
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)

        # Split Q, K, V into heads: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask should broadcast over the head dimension, e.g. shape (batch, 1, 1, seq_len)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalize the scores into attention weights
        attention_weights = torch.softmax(scores, dim=-1)

        # Weighted sum over the values
        output = torch.matmul(attention_weights, value)

        # Concatenate the heads: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        # Final output projection
        output = self.output_linear(output)
        return output, attention_weights
```
This is a simple implementation of multi-head self-attention, where `d_model` is the input/output feature dimension and `num_heads` is the number of attention heads. In the `forward` method, Q, K, and V are first obtained through linear projections and then split into multiple heads; the scaled dot-product attention scores are computed and normalized with softmax, the values are combined by a weighted sum, and finally the heads are concatenated and passed through an output linear layer. Note that the `mask` argument is optional and that positional encoding is not included here; you can adjust and extend the module to suit your needs.
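For reference, here is a minimal usage sketch; the specific shapes, `d_model=512`, `num_heads=8`, and the all-ones padding mask are illustrative assumptions rather than part of the original post:

```python
# Minimal usage sketch with assumed shapes
model = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)       # (batch_size, seq_len, d_model)
mask = torch.ones(2, 1, 1, 10)    # broadcastable padding mask; positions equal to 0 would be masked out
out, attn = model(x, x, x, mask)  # self-attention: query, key and value are the same tensor
print(out.shape)    # torch.Size([2, 10, 512])
print(attn.shape)   # torch.Size([2, 8, 10, 10])
```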