attention层的运用代码举例
时间: 2023-06-27 17:08:15 浏览: 50
下面是使用PyTorch实现的一个简单的self-attention层的示例代码:
```python
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, hidden_size):
super(SelfAttention, self).__init__()
self.hidden_size = hidden_size
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(hidden_size, hidden_size)
self.value = nn.Linear(hidden_size, hidden_size)
self.softmax = nn.Softmax(dim=-1)
def forward(self, inputs):
# inputs shape: [batch_size, seq_len, hidden_size]
Q = self.query(inputs) # shape: [batch_size, seq_len, hidden_size]
K = self.key(inputs) # shape: [batch_size, seq_len, hidden_size]
V = self.value(inputs) # shape: [batch_size, seq_len, hidden_size]
# 计算注意力分数
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.hidden_size ** 0.5) # shape: [batch_size, seq_len, seq_len]
# 对注意力分数进行softmax归一化
attention_probs = self.softmax(attention_scores) # shape: [batch_size, seq_len, seq_len]
# 使用注意力权重加权求和
attention_output = torch.matmul(attention_probs, V) # shape: [batch_size, seq_len, hidden_size]
return attention_output
```
在这个示例中,我们定义了一个`SelfAttention`模块,构建时需要指定隐层维度`hidden_size`。在`forward`方法中,我们通过`nn.Linear`定义了三个线性层分别作为query、key和value的变换矩阵。然后,我们根据query、key和value计算注意力分数,对分数进行softmax归一化,最后使用注意力权重加权求和得到attention输出。
相关推荐
![docx](https://img-home.csdnimg.cn/images/20210720083331.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)