self attention的torch代码
时间: 2023-08-27 16:07:11 浏览: 74
基于attention+ gru文本分类代码
5星 · 资源好评率100%
以下是使用PyTorch实现self-attention的代码:
```
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads*self.head_dim, embed_size)
def forward(self, values, keys, query, mask):
# Get number of training examples
N = query.shape[0]
# Split embedding into self.heads pieces
values = values.reshape(N, -1, self.heads, self.head_dim)
keys = keys.reshape(N, -1, self.heads, self.head_dim)
query = query.reshape(N, -1, self.heads, self.head_dim)
# Compute the energy between queries and keys for each head
energy = torch.einsum("nqhd,nkhd->nhqk", [query, keys])
# Add the mask to the energy (if present)
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# Apply softmax to get attention scores
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
# Compute the weighted average of the values
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, -1, self.heads*self.head_dim
)
# Apply output fully connected layer
out = self.fc_out(out)
return out
```
这是一个名为`SelfAttention`的PyTorch模块,它实现了self-attention的前向传递。在`__init__`函数中,我们定义了值、键和查询矩阵的线性层。在`forward`函数中,我们首先将输入张量分为self.heads个头,然后计算每个头的能量。接下来,我们应用softmax函数来获取注意力分数,并计算加权平均值。最后,我们将输出送入一个全连接层,以获得最终的self-attention表示。
阅读全文