帮我用python 写一个注意力机制代码
时间: 2024-12-20 11:30:53 浏览: 2
注意力机制测试代码(Python & 人工智能)
5星 · 资源好评率100%
注意力机制是一种在深度学习中用于处理序列数据的关键组件,特别是在自然语言处理(NLP)任务中,如Transformer模型。这里是一个简单的自注意力机制的Python实现,基于PyTorch框架:
```python
import torch
from torch import nn
class Attention(nn.Module):
def __init__(self, dim_q, dim_k, dim_v, dropout=0.1):
super().__init__()
self.query = nn.Linear(dim_q, dim_k)
self.key = nn.Linear(dim_k, dim_k)
self.value = nn.Linear(dim_v, dim_v)
self.dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
def forward(self, q, k, v):
# 计算query-key点积并归一化
scores = torch.matmul(q, k.transpose(-2, -1)) / (k.norm(dim=-1, keepdim=True) + 1e-8)
attn_weights = self.softmax(scores)
attn_weights = self.dropout(attn_weights)
# 使用attention权重加权值向量
output = torch.matmul(attn_weights, v)
return output, attn_weights
# 示例用法
model = Attention(dim_q=64, dim_k=64, dim_v=64)
query = torch.randn(10, 32, 64) # [batch_size, seq_len, dim_q]
key = torch.randn(10, 64, 64) # [batch_size, dim_k, seq_len]
value = torch.randn(10, 64, 64) # [batch_size, dim_v, seq_len]
output, attention_weights = model(query, key, value)
```
这个`Attention`模块接受查询、键和值作为输入,并返回加权后的值以及注意力权重。注意这只是一个基础版的注意力机制,实际应用中可能会包含更多的复杂操作,比如位置编码等。
阅读全文