Multi-Head Self-Attention: A Code Walkthrough
Posted: 2024-04-21 20:21:36
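Multi-head self-attention is an attention mechanism for sequence data in which every position attends to every other position of the same sequence, with several attention heads running in parallel. It is widely used in natural language processing tasks. This post walks through an implementation in PyTorch.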
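For reference, the standard formulation of the mechanism can be written as follows. The symbols here ($X$ for the input, $h$ for the number of heads, $d_k$ for the per-head dimension, and the projection matrices $W_i^Q$, $W_i^K$, $W_i^V$, $W^O$) follow common notation and are not names taken from this post.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V, \qquad d_k = \frac{d_{\text{model}}}{h}
$$

$$
\mathrm{head}_i = \mathrm{Attention}\!\left(X W_i^{Q},\; X W_i^{K},\; X W_i^{V}\right), \qquad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
$$

In an implementation, the $h$ per-head projections are typically fused into a single `d_model × d_model` linear layer whose output is then reshaped into heads.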
First, import the required libraries and modules:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```
Next, define a class that implements multi-head self-attention:
```python
class MultiheadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # d_model must split evenly across the heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        # Linear projections for queries, keys, and values
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        # Linear projection applied to the concatenated heads
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        batch_size, seq_len, _ = x.size()
        # Project the input to queries, keys, and values
        query = self.query_linear(x)
        key = self.key_linear(x)
        value = self.value_linear(x)
        # Split into heads: (batch_size, num_heads, seq_len, head_dim)
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product scores, divided by sqrt(head_dim), the per-head key size
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the values: (batch_size, num_heads, seq_len, head_dim)
        weighted_values = torch.matmul(attention_weights, value)
        # Concatenate the heads back to (batch_size, seq_len, d_model)
        weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        # Final output projection
        output = self.output_linear(weighted_values)
        return output
```
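The reshaping in `forward` is the part that most often causes confusion, so here is a minimal, self-contained sketch of just the split and merge steps. The sizes are hypothetical values chosen for illustration, and the raw input stands in for the projected queries and keys (the linear layers are skipped), so this only demonstrates the tensor shapes, not a trained attention layer.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4
head_dim = d_model // num_heads  # 4 features per head

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, head_dim)
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Per-head attention weights; each row is a distribution over key positions.
scores = torch.matmul(heads, heads.transpose(-2, -1)) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)

# Merge: (batch, heads, seq, head_dim) -> (batch, seq, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

print(heads.shape)             # torch.Size([2, 4, 5, 4])
print(weights.shape)           # torch.Size([2, 4, 5, 5])
print(torch.equal(merged, x))  # True: merging exactly undoes the split
```

The `.contiguous()` call matters because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout. With the class defined above, calling `MultiheadSelfAttention(16, 4)(x)` on the same `x` should likewise return a tensor of shape `(2, 5, 16)`.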
That concludes the code walkthrough of multi-head self-attention.