自注意力机制代码实现
时间: 2024-04-23 14:20:23 浏览: 125
自注意力机制(Self-Attention)是一种用于处理序列数据的机制,常用于自然语言处理任务中。下面是一个简单的自注意力机制的代码实现示例:
```python
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, input_dim):
super(SelfAttention, self).__init__()
self.query = nn.Linear(input_dim, input_dim)
self.key = nn.Linear(input_dim, input_dim)
self.value = nn.Linear(input_dim, input_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
q = self.query(x)
k = self.key(x)
v = self.value(x)
scores = torch.matmul(q, k.transpose(-2, -1))
attention_weights = self.softmax(scores)
output = torch.matmul(attention_weights, v)
return output
# 使用示例
input_dim = 512
seq_len = 10
batch_size = 32
input_data = torch.randn(batch_size, seq_len, input_dim)
self_attention = SelfAttention(input_dim)
output = self_attention(input_data)
print(output.shape)
```
上述代码中,`SelfAttention` 类定义了一个自注意力机制模块。在 `forward` 方法中,通过线性变换将输入数据 `x` 映射到查询(query)、键(key)和值(value)空间。然后计算查询和键的相似度得分,并通过 softmax 函数将得分转化为注意力权重。最后,将注意力权重与值相乘得到输出。
相关问题:
1. 什么是自注意力机制?
2. 自注意力机制的作用是什么?
3. 自注意力机制在自然语言处理中的应用有哪些?
4. 自注意力机制与传统的注意力机制有什么区别?
5. 自注意力机制的计算复杂度如何?
阅读全文