pytorch self attention
时间: 2024-12-27 15:24:40 浏览: 9
### 实现和使用Self-Attention机制
在PyTorch中实现自注意力(Self-Attention)机制涉及创建一个模块,该模块接受输入序列并计算查询(Q)、键(K)和值(V)。这些张量用于计算注意力权重,并最终生成加权后的特征表示。
#### 创建Self-Attention层
为了构建自注意力机制,在`nn.Module`类的基础上定义一个新的类。此新类初始化时设置线性变换参数矩阵W_q、W_k和W_v,分别对应于查询、键和值的投影[^4]:
```python
import torch.nn as nn
import torch
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
# 定义三个全连接层来映射到 Q,K,V
self.W_q = nn.Linear(embed_size, embed_size)
self.W_k = nn.Linear(embed_size, embed_size)
self.W_v = nn.Linear(embed_size, embed_size)
self.fc_out = nn.Linear(embed_size, embed_size)
```
对于前向传播过程中的每一个时间步t,通过上述定义好的线性层将输入嵌入转换成对应的Q、K、V矩阵。接着按照缩放点积注意公式计算得分S,并应用softmax函数得到概率分布A作为注意力权重。最后利用这个权重对V做加权求和获得上下文向量C_t:
```python
def forward(self, values, keys, query, mask=None):
N = query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
values = self.W_v(values)
keys = self.W_k(keys)
queries = self.W_q(query)
# Split embedding into self.heads pieces.
values = values.reshape(N, value_len, self.heads, self.embed_size // self.heads)
keys = keys.reshape(N, key_len, self.heads, self.embed_size // self.heads)
queries = queries.reshape(N, query_len, self.heads, self.embed_size // self.heads)
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, query_len, self.embed_size
)
out = self.fc_out(out)
return out
```
在这个例子中,假设已经有一个预训练过的词嵌入表或随机初始化的嵌入层,可以将其传递给上面定义的`SelfAttention`实例来进行编码器部分的操作。
阅读全文