自注意力机制代码实例
时间: 2024-12-17 08:13:25 浏览: 3
自注意力机制(Self-Attention Mechanism),源于Transformer模型,是一种在自然语言处理(NLP)中广泛使用的机制,用于捕捉序列数据之间的长期依赖。它允许每个位置的元素同时关注其他所有位置的信息,而不仅仅是前后的局部信息。
在PyTorch这样的深度学习框架中,你可以通过`torch.nn.MultiheadAttention`模块来实现自注意力。下面是一个简单的代码示例:
```python
import torch
from torch import nn
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads=8):
super(SelfAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, _ = query.size()
query = self.query(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = self.key(key).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = self.value(value).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 计算注意力权重
attn_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
attn_weights = attn_weights.masked_fill(mask == 0, -float('inf'))
attn_probs = nn.functional.softmax(attn_weights, dim=-1)
context = torch.matmul(attn_probs, value)
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
output = self.out_proj(context)
return output, attn_probs
# 使用示例
query = torch.rand((16, 32, 768))
key = value = torch.rand((16, 32, 768))
attention_module = SelfAttention(768)
output, attention_scores = attention_module(query, key, value)
```
在这个例子中,输入是查询、键和值张量,`mask`用于屏蔽掉某些位置的注意力。`forward`函数返回经过注意力计算后的新特征向量以及注意力得分。
阅读全文