Provide code for a self-attention mechanism
Below is a simple implementation using the PyTorch framework:
```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

        # Per-head linear projections for values, keys and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        batch_size = queries.shape[0]

        # Split the embedding into self.heads pieces
        values = values.reshape(batch_size, -1, self.heads, self.head_dim)
        keys = keys.reshape(batch_size, -1, self.heads, self.head_dim)
        queries = queries.reshape(batch_size, -1, self.heads, self.head_dim)

        values = self.values(values)     # (batch_size, value_len, heads, head_dim)
        keys = self.keys(keys)           # (batch_size, key_len, heads, head_dim)
        queries = self.queries(queries)  # (batch_size, query_len, heads, head_dim)

        # Compute the dot-product attention scores
        # queries: (batch_size, query_len, heads, head_dim)
        # keys:    (batch_size, key_len, heads, head_dim)
        # energy:  (batch_size, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(head_dim) and normalize over the key dimension
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        # attention: (batch_size, heads, query_len, key_len)

        # Weighted sum of the values, then concatenate the heads
        # attention: (batch_size, heads, query_len, key_len)
        # values:    (batch_size, value_len, heads, head_dim)
        # out:       (batch_size, query_len, heads * head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            batch_size, -1, self.heads * self.head_dim
        )

        out = self.fc_out(out)  # (batch_size, query_len, embed_size)
        return out
```
Here, `embed_size` is the dimensionality of the input embedding vectors and `heads` is the number of attention heads. In `forward`, the input embeddings are first split into multiple heads; separate linear layers then produce the queries, keys and values for each head. The dot-product attention scores are computed, scaled, and normalized with softmax, and the weighted values from all heads are concatenated and passed through a final fully connected layer to produce the output. An optional mask can be passed in to exclude invalid positions (e.g. padding) from the attention scores.
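As a quick sanity check, here is a minimal usage sketch; the tensor shapes and the padding-mask construction are illustrative assumptions, not part of the original answer:

```python
import torch

embed_size, heads = 256, 8
batch_size, seq_len = 2, 10

attention = SelfAttention(embed_size, heads)

# Dummy input: one sequence of token embeddings used as queries, keys and values
x = torch.randn(batch_size, seq_len, embed_size)

# Example padding mask: 1 for real tokens, 0 for padded positions.
# Shaped (batch_size, 1, 1, key_len) so it broadcasts against the
# energy tensor of shape (batch_size, heads, query_len, key_len).
mask = torch.ones(batch_size, 1, 1, seq_len)
mask[:, :, :, -2:] = 0  # pretend the last two positions are padding

out = attention(x, x, x, mask)
print(out.shape)  # torch.Size([2, 10, 256])
```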