Improved self-attention mechanism code
Below is a simple improvement of the self-attention mechanism that uses multi-head attention:
```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

        # Per-head linear projections for values, keys, and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Final projection after concatenating all heads
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        batch_size = queries.shape[0]

        # Split the embedding into self.heads pieces of size head_dim
        values = values.reshape(batch_size, -1, self.heads, self.head_dim)
        keys = keys.reshape(batch_size, -1, self.heads, self.head_dim)
        queries = queries.reshape(batch_size, -1, self.heads, self.head_dim)

        # Apply the per-head linear projections
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Attention scores (energy) for each head:
        # queries: (N, query_len, heads, head_dim), keys: (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scaled softmax over the key dimension
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Weighted sum of values per head, then concatenate all heads:
        # out: (N, query_len, heads * head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            batch_size, -1, self.heads * self.head_dim
        )

        # Final linear projection back to embed_size
        out = self.fc_out(out)
        return out
```
In this code, we split the input embedding into multiple heads, each with dimension `embed_size // heads`. For each head we then compute the query, key, and value projections. We use `torch.einsum` to compute the attention energy and the output, which avoids explicit for loops over the heads and speeds up the computation. Finally, the outputs of all heads are concatenated and passed through a fully connected layer to produce the final attention features.
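As a quick sanity check, the module can be exercised with random tensors. A minimal usage sketch follows; the concrete shapes (batch size 2, sequence length 10, `embed_size=256`, `heads=8`) are illustrative choices, not values from the original post:

```python
# Illustrative usage (shapes chosen for the example, not from the original post)
embed_size, heads = 256, 8
attention = SelfAttention(embed_size, heads)

x = torch.rand(2, 10, embed_size)   # (batch, seq_len, embed_size)
mask = torch.ones(2, 1, 1, 10)      # broadcasts over (N, heads, query_len, key_len)

# Self-attention: values, keys, and queries are all the same tensor
out = attention(x, x, x, mask)
print(out.shape)                    # torch.Size([2, 10, 256])
```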