Multi-head attention code
Here is a simple implementation of multi-head attention in PyTorch:
```python
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        return x

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # linear transformations
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        # split into multiple heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)
        # scaled dot-product attention; scores: (batch, num_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention = torch.softmax(scores, dim=-1)
        x = torch.matmul(attention, value)
        # concatenate attention heads: back to (batch, seq_len, d_model)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.d_model)
        # final linear transformation
        x = self.fc(x)
        return x
```
This implementation projects the `query`, `key`, and `value` inputs (each of dimension `d_model`) with separate linear layers and splits each projection into `num_heads` heads of size `head_dim = d_model // num_heads`. Scaled dot-product attention is then computed per head, with the scores divided by the square root of `head_dim`, and finally the heads are concatenated and passed through one more linear layer.
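If it helps to see the core attention step in isolation, here is a minimal single-head sketch; the tensor shapes are hypothetical and chosen only for illustration:
```python
import torch

# hypothetical single-head tensors of shape (batch, seq_len, head_dim)
q = torch.randn(2, 10, 64)
k = torch.randn(2, 10, 64)
v = torch.randn(2, 10, 64)

scores = q @ k.transpose(-2, -1) / 64 ** 0.5   # (2, 10, 10) attention scores
weights = torch.softmax(scores, dim=-1)        # each row sums to 1
out = weights @ v                              # (2, 10, 64) weighted sum of values
```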
To use this module in your Transformer, you can simply call it like this:
```python
attn = MultiHeadAttention(d_model=512, num_heads=8)

# query, key and value are (batch_size, seq_len, d_model) tensors;
# for self-attention they are all the same tensor
x = torch.randn(2, 10, 512)
output = attn(x, x, x)   # (2, 10, 512)
```
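The optional `mask` argument is not shown above. As a rough sketch of one common pattern, a causal (look-ahead) mask can be passed in any shape that broadcasts against the score tensor `(batch_size, num_heads, seq_len, seq_len)`; positions where the mask is 0 are blocked:
```python
seq_len = x.size(1)
# lower-triangular causal mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))   # (seq_len, seq_len)
causal_mask = causal_mask[None, None, :, :]              # (1, 1, seq_len, seq_len), broadcastable
masked_output = attn(x, x, x, mask=causal_mask)          # (2, 10, 512)
```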