自注意力层pytorch
时间: 2025-01-08 07:37:27 浏览: 1
### 如何在 PyTorch 中实现自注意力层
#### 创建自定义的自注意力模块
为了实现在PyTorch中的自注意力机制,可以构建一个名为`SelfAttention`的新类继承于`nn.Module`。此过程涉及到初始化权重矩阵Q(查询)、K(键)以及V(值),这些用于计算输入特征之间的相似度得分。
```python
import torch
from torch import nn, Tensor
import math
class SelfAttention(nn.Module):
def __init__(self, embed_size: int, heads: int) -> None:
super(SelfAttention, self).__init__()
assert (embed_size % heads == 0), "Embedding size needs to be divisible by number of heads"
# 定义每个头的维度大小
self.head_dim = embed_size // heads
self.embed_size = embed_size
self.num_heads = heads
# 初始化线性变换层来获取Q,K,V向量
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
# 输出投影层
self.fc_out = nn.Linear(embed_size, embed_size)[^2]
def forward(self, values: Tensor, keys: Tensor, query: Tensor, mask=None) -> Tensor:
N = query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# 将嵌入分割成多头并调整形状以便后续处理
values = values.reshape(N, value_len, self.num_heads, self.head_dim)
keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)
queries= query.reshape(N, query_len, self.num_heads, self.head_dim)
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, query_len, self.embed_size
)
out = self.fc_out(out)
return out
```
这段代码展示了如何创建一个多头自注意机制,在这里`einsum`函数被用来高效地执行张量间的乘法运算[^3]。对于每一个头部,分别应用了不同的线性转换以获得各自的Q、K、V表示形式;之后再利用缩放点积注意力公式计算注意力分布,并最终通过另一个全连接层映射回原始尺寸的空间内[^1]。
阅读全文