Applying the multi-head self-attention mechanism in PyTorch (code example)
Multi-head self-attention is a core building block of the Transformer architecture: it lets the model learn different feature representations in parallel across several attention heads. Below is an example of implementing multi-head self-attention in PyTorch:
```python
import torch
import torch.nn as nn


class MultiheadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiheadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads  # dimension of each attention head
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections
        query = self.linear_q(query)
        key = self.linear_k(key)
        value = self.linear_v(value)

        # Reshape tensors: split d_model into num_heads heads of size d_k
        query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (batch_size, num_heads, seq_length_q, d_k)
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)      # (batch_size, num_heads, seq_length_k, d_k)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (batch_size, num_heads, seq_length_v, d_k)

        # Scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # Apply mask to scores
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, value)

        # Reshape and concatenate attention heads
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        # Final linear layer
        output = self.linear_out(context)
        return output, attention_weights
```
The code above defines a `MultiheadAttention` class that inherits from PyTorch's `nn.Module`. In the `forward` method, the query, key, and value tensors are first passed through linear projections and then reshaped so that the model dimension is split across the attention heads. Attention scores are computed by matrix multiplication and scaled by the square root of the head dimension; if a mask is provided, masked positions are suppressed before the softmax. The softmax of the scores gives the attention weights, which are multiplied with the value tensor to produce the per-head context. Finally, the context is reshaped back, the heads are concatenated, and an output linear layer produces the final result.
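To illustrate the mask semantics just described (positions where the mask equals 0 are filled with -1e9 before the softmax), here is a minimal usage sketch that runs the module in self-attention mode with a causal mask. The hyperparameters (d_model=512, num_heads=8, batch size 2, sequence length 10) are arbitrary values chosen only for this example:

```python
import torch

# Assumed example hyperparameters (not from the original text)
d_model, num_heads = 512, 8
batch_size, seq_len = 2, 10

mha = MultiheadAttention(d_model, num_heads)

# Self-attention: query, key, and value are the same tensor
x = torch.randn(batch_size, seq_len, d_model)

# Causal (lower-triangular) mask: entries equal to 0 are blocked in the score matrix;
# the (seq_len, seq_len) mask broadcasts over the batch and head dimensions
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, attention_weights = mha(x, x, x, mask=causal_mask)
print(output.shape)             # torch.Size([2, 10, 512])
print(attention_weights.shape)  # torch.Size([2, 8, 10, 10])
```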
You can use this multi-head self-attention module as a building block for your own Transformer model, adapting it to your needs.
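For comparison, PyTorch also ships a built-in `torch.nn.MultiheadAttention` module implementing the same mechanism. A minimal sketch of calling it with the same shapes as above (note that `batch_first=True` requires a reasonably recent PyTorch version, and the returned attention weights are averaged over the heads by default):

```python
import torch
import torch.nn as nn

# Built-in module; batch_first=True keeps the (batch, seq, d_model) layout used above
builtin_mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)

# Self-attention with the built-in module
attn_output, attn_weights = builtin_mha(x, x, x)
print(attn_output.shape)   # torch.Size([2, 10, 512])
print(attn_weights.shape)  # torch.Size([2, 10, 10]) -- averaged over heads by default
```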