Python Code for the Self-Attention Mechanism
A Python implementation of the (multi-head) self-attention mechanism is shown below:
```python
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_attention_heads):
        super(SelfAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = hidden_size
        # Linear projections that produce queries, keys, and values
        self.query = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.all_head_size)

    def forward(self, inputs):
        # inputs: (batch_size, seq_len, input_size)
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)
        # Split into heads: (batch_size, num_heads, seq_len, head_size)
        q = q.view(q.size(0), -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        k = k.view(k.size(0), -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        v = v.view(v.size(0), -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        # Scaled dot-product attention
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        context = torch.matmul(attention_probs, v)
        # Merge the heads back: (batch_size, seq_len, hidden_size)
        context = context.transpose(1, 2).contiguous().view(context.size(0), -1, self.all_head_size)
        return context
```
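A minimal usage sketch follows; the batch size, sequence length, and dimensions below are illustrative assumptions, not values from the original:

```python
# Hypothetical example: a batch of 2 sequences, each with 10 tokens of 64-dim features
attention = SelfAttention(input_size=64, hidden_size=64, num_attention_heads=8)
x = torch.randn(2, 10, 64)   # (batch_size, seq_len, input_size)
out = attention(x)
print(out.shape)             # torch.Size([2, 10, 64]) -> (batch_size, seq_len, hidden_size)
```

Note that hidden_size must be divisible by num_attention_heads so each head gets an equal share of the hidden dimension.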