Help me find a piece of code that uses an attention mechanism to perform a classification task
Attention mechanisms are widely used in deep learning models, especially for sequence data such as in natural language processing tasks. Below is a simple PyTorch example of a basic self-attention layer (split into multiple heads), which can serve as a building block for a text classification model:
```python
import torch
from torch import nn

class Attention(nn.Module):
    def __init__(self, d_model, heads=8):
        super(Attention, self).__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads          # dimension per head
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)  # heads * d_k == d_model

    def forward(self, q, k, v):
        batch_size = q.size(0)
        # Project and split into heads: (batch, heads, seq_len, d_k)
        q = self.wq(q).view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)
        k = self.wk(k).view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)
        v = self.wv(v).view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores, softmax over the key positions
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = nn.functional.softmax(scores, dim=-1)
        # Weight the value vectors with the attention weights, then merge the heads
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous()
        new_context = context.view(batch_size, -1, self.d_model)
        return self.fc(new_context)

# Example usage (inputs must be 3-D: batch_size, seq_len, embedding_dim)
input_q = torch.randn(64, 10, 128)            # query sequence
input_k = input_v = torch.randn(64, 20, 128)  # key and value sequences
attention_layer = Attention(128)              # d_model = 128, divisible by 8 heads
output = attention_layer(input_q, input_k, input_v)  # shape: (64, 10, 128)
```
Note that this is only a basic example. Real applications usually need a more elaborate structure, such as the full Multi-Head Attention used in the Transformer, and the attention layer is normally combined with other layers (embeddings, MLPs, pooling, etc.) as part of a complete model.
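Since the question asks about classification specifically, here is a minimal sketch of how the `Attention` layer above could feed a classification head. The `TextClassifier` class, its hyperparameters (`vocab_size`, `num_classes`), and the mean-pooling step are illustrative assumptions, not part of the original answer:

```python
import torch
from torch import nn

class TextClassifier(nn.Module):
    """Sketch: embedding -> self-attention -> mean pooling -> linear classifier.
    Hyperparameters here are assumptions chosen for demonstration only."""
    def __init__(self, vocab_size=10000, d_model=128, heads=8, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.attention = Attention(d_model, heads)   # the Attention class defined above
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, token_ids):
        x = self.embedding(token_ids)     # (batch, seq_len, d_model)
        x = self.attention(x, x, x)       # self-attention: q = k = v
        x = x.mean(dim=1)                 # mean-pool over the sequence positions
        return self.classifier(x)         # (batch, num_classes) logits

# Example: a batch of 32 sequences, each 50 tokens long
tokens = torch.randint(0, 10000, (32, 50))
model = TextClassifier()
logits = model(tokens)                    # (32, 2)
loss = nn.functional.cross_entropy(logits, torch.randint(0, 2, (32,)))
```

Mean pooling is just one simple way to reduce the sequence to a single vector; a learned [CLS] token or max pooling would also work, depending on the task.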