Multi-head self-attention code
Below is a simple code example of multi-head self-attention:
```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # separate linear projections for queries, keys, and values
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # perform the linear projections and split into num_heads heads
        k = self.k_linear(k).view(bs, -1, self.num_heads, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.num_heads, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.num_heads, self.d_k)
        # transpose to get dimensions bs * num_heads * seq_len * d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # compute scaled dot-product attention per head
        scores = self.attention(q, k, v, self.d_k, mask)
        # concatenate the heads and put them through the final linear layer
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.num_heads * self.d_k)
        output = self.out(concat)
        return output

    def attention(self, q, k, v, d_k, mask=None):
        # scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # broadcast the mask over the head dimension and block masked positions
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(scores, v)
        return output
```
This code implements a multi-head self-attention module that can be used as a building block for models such as the Transformer.
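As a quick sanity check, the module can be called with the same tensor for q, k, and v to get self-attention. The sketch below is only illustrative; the batch size, sequence length, d_model, and number of heads are assumed values, not anything prescribed by the code above.
```python
# Minimal usage sketch; the dimensions here are illustrative assumptions.
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

x = torch.randn(2, 10, d_model)    # (batch, seq_len, d_model)
mask = torch.ones(2, 10, 10)       # 1 = attend, 0 = block; broadcast over heads

out = mha(x, x, x, mask)           # self-attention: q = k = v = x
print(out.shape)                   # torch.Size([2, 10, 512])
```
Note that the output has the same shape as the input, which is what allows such attention blocks to be stacked with residual connections in a Transformer encoder.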