带mask的自注意力机制代码
时间: 2024-04-27 15:17:53 浏览: 15
带mask的自注意力机制是一种常用于自然语言处理任务的模型,其中最著名的是Transformer模型。下面是一个简化的带mask的自注意力机制的代码示例:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(SelfAttention, self).__init__()
self.query = nn.Linear(input_dim, hidden_dim)
self.key = nn.Linear(input_dim, hidden_dim)
self.value = nn.Linear(input_dim, hidden_dim)
def forward(self, inputs, mask):
q = self.query(inputs)
k = self.key(inputs)
v = self.value(inputs)
scores = torch.matmul(q, k.transpose(-2, -1))
# 应用mask,将无效位置的得分置为负无穷大
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attended = torch.matmul(attn_weights, v)
return attended
# 使用示例
input_dim = 512
hidden_dim = 64
seq_len = 10
batch_size = 16
# 假设有一个输入序列,形状为(batch_size, seq_len, input_dim)
inputs = torch.randn(batch_size, seq_len, input_dim)
# 假设有一个mask矩阵,形状为(batch_size, seq_len),其中1表示有效位置,0表示无效位置
mask = torch.ones(batch_size, seq_len)
mask[:, 5:] = 0
# 创建SelfAttention模型并进行前向计算
self_attention = SelfAttention(input_dim, hidden_dim)
attended = self_attention(inputs, mask)
print(attended.shape) # 输出形状为(batch_size, seq_len, hidden_dim)
```
带mask的自注意力机制的代码示例中,SelfAttention模型接收输入`inputs`和对应的mask矩阵,并输出经过注意力权重加权后的结果`attended`。在计算注意力权重时,根据mask矩阵将无效位置的得分置为负无穷大,以使无效位置在softmax计算时得到的概率接近于0。