PyTorch multi-head attention code
Time: 2024-08-28 10:00:19
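In PyTorch, multi-head attention is a key building block of deep learning models for sequence data, most notably the Transformer architecture. It lets the model attend to different parts of the input at the same time: the embedding is split across several "heads", each using its own slice of the learned Q/K/V projections, so different heads can learn different attention patterns.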
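To make the "heads" concrete: splitting into heads is just a reshape of the last dimension. The minimal sketch below uses arbitrary illustrative sizes (batch 2, sequence length 5, `embed_dim=8`, `num_heads=2`; none of these come from the original) and shows the `view`/`transpose` pattern that turns a `(batch, seq, embed)` tensor into `(batch, heads, seq, head_dim)` and back.

```python
import torch

# Illustrative sizes only (assumed for this sketch)
batch_size, seq_len, embed_dim, num_heads = 2, 5, 8, 2
head_dim = embed_dim // num_heads  # each head sees embed_dim / num_heads features

x = torch.rand(batch_size, seq_len, embed_dim)

# Split: (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 2, 5, 4])

# Head h holds features [h*head_dim, (h+1)*head_dim) of every token
print(torch.equal(heads[:, 1], x[..., head_dim:2 * head_dim]))  # True

# Merge back: (batch, heads, seq, head_dim) -> (batch, seq, embed)
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
print(torch.equal(merged, x))  # True
```

The `.contiguous()` call is needed before the final `view`, because `transpose` returns a non-contiguous view of the same memory.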
Below is a simple PyTorch implementation of multi-head attention, used here as a self-attention layer:
```python
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = nn.Dropout(dropout)
        # Linear projections that produce Q, K and V
        self.query_linear = nn.Linear(embed_dim, embed_dim)
        self.key_linear = nn.Linear(embed_dim, embed_dim)
        self.value_linear = nn.Linear(embed_dim, embed_dim)
        # Output projection applied after the heads are concatenated (W^O in the Transformer)
        self.out_linear = nn.Linear(embed_dim, embed_dim)
        # Scaling factor 1 / sqrt(head_dim) for the dot products
        self.scale = self.head_dim ** -0.5

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        # Project, then split into heads: (batch, seq, embed) -> (batch, heads, seq, head_dim)
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores: (batch, heads, seq_q, seq_k)
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Mask convention: 1 = may attend, 0 = blocked
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = self.dropout(torch.softmax(attention_scores, dim=-1))
        # Weighted sum of the values, then merge the heads back: (batch, seq_q, embed)
        context = torch.matmul(attention_weights, value).transpose(1, 2).contiguous().view(
            batch_size, -1, self.embed_dim)
        return self.out_linear(context), attention_weights


# Example usage (self-attention: query, key and value are the same tensor)
x = torch.rand(64, 10, 768)  # (batch_size, seq_len, embed_dim)
# Causal mask: 1 on and below the diagonal, so position i attends only to positions <= i
mask = torch.tril(torch.ones(10, 10))
multi_head_attention = MultiHeadAttention(embed_dim=768, num_heads=12)
att_output, att_weights = multi_head_attention(x, x, x, mask)
print(att_output.shape)   # torch.Size([64, 10, 768])
print(att_weights.shape)  # torch.Size([64, 12, 10, 10])
```
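For comparison, PyTorch also ships a built-in layer, `torch.nn.MultiheadAttention`. The sketch below assumes a PyTorch version that supports `batch_first=True` (1.9 or later) and reuses the example's sizes (batch 64, sequence length 10, embedding 768), with 12 heads as an illustrative choice. One pitfall worth noting: for a boolean `attn_mask`, `nn.MultiheadAttention` treats `True` as "not allowed to attend", the opposite of the 1 = keep convention implied by `masked_fill(mask == 0, ...)` in the hand-written class, so the causal mask here is the strict upper triangle.

```python
import torch
from torch import nn

batch_size, seq_len, embed_dim, num_heads = 64, 10, 768, 12

# batch_first=True makes inputs (batch, seq, embed), matching the hand-written class
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1, batch_first=True)

x = torch.rand(batch_size, seq_len, embed_dim)

# Boolean attn_mask: True = blocked. Block every position above the diagonal (the future).
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x, attn_mask=causal_mask)
print(out.shape)      # torch.Size([64, 10, 768])
print(weights.shape)  # torch.Size([64, 10, 10]), averaged over heads by default
```

The built-in layer applies a learned output projection (`out_proj`) after merging the heads, and by default it returns attention weights averaged over the heads, which is why `weights` has one `(10, 10)` matrix per batch element rather than one per head.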