How multi-head attention works
Date: 2023-09-18 22:05:59
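Multi-head attention is an attention mechanism widely used in Transformer models. It lets the model attend to different positions and to different representation subspaces at the same time, which helps it capture richer semantic information.
Multi-head attention works as follows:
1. Linearly project the input into queries, keys, and values.
2. Split each projection into multiple heads.
3. Compute attention scores in each head as scaled dot products between queries and keys.
4. Normalize the scores with softmax.
5. Use the normalized scores to take a weighted sum of the values, then concatenate the heads and apply a final linear projection.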
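In formula form, the steps above correspond to the standard scaled dot-product formulation. The symbols used here are notation introduced for illustration, not taken from the text above: $W_i^{Q}$, $W_i^{K}$, $W_i^{V}$ and $W^{O}$ are the projection matrices, $h$ is the number of heads, and $d_k$ is the per-head dimension.

```latex
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}), \quad i = 1, \dots, h \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
\end{aligned}
```

Steps 1 and 2 are the per-head projections, steps 3 and 4 are the scaled product and the softmax, and step 5 is the multiplication by $V$. The heads are then concatenated and projected by $W^{O}$.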
Multi-head attention code
Here is a simple implementation of multi-head attention in PyTorch:
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        return x

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # linear transformations
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        # split into multiple heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)
        # scaled dot-product attention; scores: (batch, num_heads, q_len, k_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim ** 0.5
        if mask is not None:
            # mask must broadcast to the shape of scores; positions where mask == 0 are hidden
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention = torch.softmax(scores, dim=-1)
        x = torch.matmul(attention, value)
        # concatenate attention heads: back to (batch, q_len, d_model)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.d_model)
        # final linear transformation
        x = self.fc(x)
        return x
This implementation takes `query`, `key`, and `value` tensors of shape `(batch_size, seq_len, d_model)`. Each is linearly transformed and then split into `num_heads` heads of size `d_model // num_heads`. Scaled dot-product attention is computed in every head in parallel, with positions where `mask == 0` excluded if a `mask` is given, and the heads are then concatenated and linearly transformed again.
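The split and concatenate steps are pure reshapes, which is easy to check on a small tensor. The sketch below repeats the `view` and `permute` calls from `split_heads` and `forward` outside the class. The sizes (batch 2, sequence length 5, `d_model` 8, 2 heads) are arbitrary values chosen for illustration.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 8, 2  # illustrative sizes
head_dim = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
heads = x.view(batch_size, -1, num_heads, head_dim).permute(0, 2, 1, 3)
print(heads.shape)  # torch.Size([2, 2, 5, 4])

# Head 0 holds the first head_dim features of every position, head 1 the next head_dim.
first_head_matches = torch.equal(heads[:, 0], x[..., :head_dim])

# Concatenate: undo the permute, then merge the last two dimensions again.
merged = heads.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, d_model)
round_trip_ok = torch.equal(merged, x)
print(first_head_matches, round_trip_ok)  # True True
```

The `.contiguous()` call is needed because `permute` only changes strides, and `view` requires a compatible memory layout.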
To use this module in your Transformer, you can simply call it like this:
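attn = MultiHeadAttention(d_model=512, num_heads=8)
query = key = value = torch.randn(2, 10, 512)  # (batch_size, seq_len, d_model)
output = attn(query, key, value)  # shape: (2, 10, 512)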
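The optional `mask` argument is not exercised by the call above. As a sketch of how it behaves, the snippet below applies the same `masked_fill` and `softmax` steps as `forward` to random scores, using a causal (lower-triangular) mask. The causal mask is an assumption chosen for illustration; the original text does not say which kind of mask is intended.

```python
import torch

num_heads, seq_len = 2, 4  # illustrative sizes
scores = torch.randn(1, num_heads, seq_len, seq_len)  # (batch, heads, q_len, k_len)

# Causal mask: 1 where a query position may look at a key position, 0 elsewhere.
# Shape (seq_len, seq_len) broadcasts over the batch and head dimensions.
mask = torch.tril(torch.ones(seq_len, seq_len))

masked = scores.masked_fill(mask == 0, float("-inf"))
attention = torch.softmax(masked, dim=-1)

# Masked positions get exactly zero weight, and every row still sums to 1.
future_weight = attention.masked_select(mask == 0).sum().item()
row_sums = attention.sum(dim=-1)
print(future_weight)  # 0.0
```

A row whose entries are all masked would become NaN after the softmax, since every score in it is `-inf`, so a mask should leave at least one visible key per query.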
The difference between self-attention and multi-head attention
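Self-attention and multi-head attention are both key components of Transformer models, used to capture long-range dependencies in the input sequence. They differ as follows:
1. In self-attention, every element of a sequence interacts with the other elements of the same sequence, producing a weighted sum that represents the element's context. Multi-head attention runs that attention operation in several different projection spaces at once, so that the model can learn different kinds of representations.
2. Self-attention computes the similarity between queries and keys and uses it to take a weighted sum of the values. Multi-head attention performs this computation in several heads, each with its own query, key, and value projection matrices, then concatenates the head outputs and applies a linear transformation to produce the final output.
3. Single-head self-attention runs one attention operation, while multi-head attention runs one per head. In the usual design each head has dimension `d_model / num_heads`, so the total cost of multi-head attention is roughly comparable to that of a single head with the full dimension, not a multiple of it.
Overall, self-attention and multi-head attention are both essential components of Transformer models, and they are not alternatives to each other. Self-attention describes where the queries, keys, and values come from (the same sequence), while multi-head describes how the attention is computed (in several parallel subspaces). A Transformer layer typically combines the two as multi-head self-attention.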
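As a rough check on the cost comparison, the sketch below counts the multiply-adds in the two matrix products of scaled dot-product attention for one head and for eight heads. The function name `attention_multiply_adds` and the sizes are hypothetical, chosen for illustration. The count assumes each head has dimension `d_model // num_heads`, as in the implementation above, and ignores the linear projections and the softmax.

```python
def attention_multiply_adds(seq_len, d_model, num_heads):
    """Count multiply-adds for the two matmuls of scaled dot-product attention.

    Counts only Q @ K^T and attention @ V, summed over heads. The input and
    output projections are left out; with d_model x d_model linear layers
    they cost the same for any num_heads.
    """
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    head_dim = d_model // num_heads
    scores = seq_len * seq_len * head_dim    # Q @ K^T in one head
    weighted = seq_len * seq_len * head_dim  # attention @ V in one head
    return num_heads * (scores + weighted)


single = attention_multiply_adds(seq_len=128, d_model=512, num_heads=1)
multi = attention_multiply_adds(seq_len=128, d_model=512, num_heads=8)
print(single, multi)  # 16777216 16777216
```

Under these assumptions the two counts are equal. The softmax is the one part that grows with the number of heads, since each head has its own `seq_len` x `seq_len` score matrix.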