注意力机制(Attention)
时间: 2024-03-10 22:42:27 浏览: 34
注意力机制(Attention)是一种在机器学习和自然语言处理中常用的技术,用于模拟人类在处理信息时的注意力分配过程。它通过对输入序列中不同位置的信息进行加权,使得模型能够更加关注与当前任务相关的部分。注意力机制在很多任务中都取得了显著的性能提升,特别是在机器翻译和阅读理解等任务中[^1]。
下面是两个关于注意力机制的例子:
1. 注意力机制在神经机器翻译中的应用:
```python
import torch
import torch.nn as nn
# 定义注意力机制模块
class Attention(nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
self.hidden_size = hidden_size
self.attn = nn.Linear(hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.rand(hidden_size))
def forward(self, hidden, encoder_outputs):
max_len = encoder_outputs.size(0)
batch_size = encoder_outputs.size(1)
# 将隐藏状态复制max_len次,以便与encoder_outputs进行拼接
hidden = hidden.unsqueeze(1).repeat(1, max_len, 1)
# 将隐藏状态和encoder_outputs拼接,并通过线性层计算注意力权重
energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
energy = energy.permute(0, 2, 1) # 转置,以便进行矩阵乘法
v = self.v.repeat(batch_size, 1).unsqueeze(1) # 扩展v向量
attention = torch.bmm(v, energy).squeeze(1) # 计算注意力权重
return torch.softmax(attention, dim=1)
# 使用注意力机制
hidden_size = 256
encoder_outputs = torch.randn(10, 32, hidden_size) # 假设encoder输出为10个长度为hidden_size的向量
hidden = torch.randn(32, hidden_size) # 假设decoder隐藏状态为长度为hidden_size的向量
attention = Attention(hidden_size)
weights = attention(hidden, encoder_outputs)
print(weights.size()) # 输出:torch.Size([32, 10])
```
2. 注意力机制在阅读理解中的应用:
```python
import torch
import torch.nn as nn
# 定义注意力机制模块
class AttentionOverAttention(nn.Module):
def __init__(self, hidden_size):
super(AttentionOverAttention, self).__init__()
self.hidden_size = hidden_size
self.attn = nn.Linear(hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.rand(hidden_size))
def forward(self, passage, question):
max_len = passage.size(0)
batch_size = passage.size(1)
# 将问题复制max_len次,以便与文章进行拼接
question = question.unsqueeze(1).repeat(1, max_len, 1)
# 将问题和文章拼接,并通过线性层计算注意力权重
energy = torch.tanh(self.attn(torch.cat([passage, question], dim=2)))
energy = energy.permute(0, 2, 1) # 转置,以便进行矩阵乘法
v = self.v.repeat(batch_size, 1).unsqueeze(1) # 扩展v向量
attention = torch.bmm(v, energy).squeeze(1) # 计算注意力权重
return torch.softmax(attention, dim=1)
# 使用注意力机制
hidden_size = 256
passage = torch.randn(10, 32, hidden_size) # 假设文章为10个长度为hidden_size的向量
question = torch.randn(32, hidden_size) # 假设问题为长度为hidden_size的向量
attention = AttentionOverAttention(hidden_size)
weights = attention(passage, question)
print(weights.size()) # 输出:torch.Size([32, 10])
```