自注意力机制(Self-Attention)代码
时间: 2024-01-08 14:20:19 浏览: 86
自注意力机制(Self-Attention)是一种注意力机制,用于计算同一序列的表示。下面是一个使用自注意力机制的代码示例:
```python
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, input_dim):
super(SelfAttention, self).__init__()
self.query = nn.Linear(input_dim, input_dim)
self.key = nn.Linear(input_dim, input_dim)
self.value = nn.Linear(input_dim, input_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
q = self.query(x)
k = self.key(x)
v = self.value(x)
scores = torch.matmul(q, k.transpose(-2, -1))
attention_weights = self.softmax(scores)
output = torch.matmul(attention_weights, v)
return output
# 使用自注意力机制
input_dim = 512
seq_length = 10
batch_size = 32
input_data = torch.randn(batch_size, seq_length, input_dim)
self_attention = SelfAttention(input_dim)
output = self_attention(input_data)
print(output.shape) # 输出:torch.Size([32, 10, 512])
```
这段代码定义了一个名为`SelfAttention`的自注意力机制模块。在`forward`方法中,输入`x`经过线性变换得到查询(query)、键(key)和值(value)的表示。然后,通过计算查询和键的点积得到注意力分数,再经过softmax函数得到注意力权重。最后,将注意力权重与值相乘得到输出。
在示例中,我们使用了一个随机生成的输入数据`input_data`,维度为(batch_size, seq_length, input_dim),其中`batch_size`表示批次大小,`seq_length`表示序列长度,`input_dim`表示输入维度。通过调用`SelfAttention`模块,我们可以得到输出`output`,其维度为(batch_size, seq_length, input_dim)。
阅读全文