自注意力机制代码python
时间: 2023-10-19 08:06:03 浏览: 177
以下是一个简单的自注意力机制的代码示例,使用Python和PyTorch库:
```python
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(SelfAttention, self).__init__()
self.hidden_dim = hidden_dim
self.query = nn.Linear(input_dim, hidden_dim)
self.key = nn.Linear(input_dim, hidden_dim)
self.value = nn.Linear(input_dim, hidden_dim)
self.softmax = nn.Softmax(dim=2)
def forward(self, x):
query = self.query(x)
key = self.key(x)
value = self.value(x)
scores = torch.bmm(query, key.transpose(1, 2))
attention_weights = self.softmax(scores)
attended_values = torch.bmm(attention_weights, value)
return attended_values
# 示例用法
input_dim = 512 # 输入维度
hidden_dim = 128 # 隐藏层维度
batch_size = 10
seq_len = 20
input_data = torch.randn(batch_size, seq_len, input_dim)
self_attention = SelfAttention(input_dim, hidden_dim)
output = self_attention(input_data)
print(output.shape) # 输出形状为 [batch_size, seq_len, hidden_dim]
```
这段代码实现了一个简单的自注意力机制。该模型将输入 x 分别映射到查询(query)、键(key)和值(value)空间,并计算查询与键之间的相似度得分。然后,使用 Softmax 函数将得分归一化为注意力权重,并将注意力权重应用于值以获得最终的注意力值。最后,返回注意力值作为输出。
请注意,这只是一个简化的示例,实际上可能需要进行更多的计算和参数调整,以适应特定的任务和数据集。
阅读全文