用最少的代码写自注意力机制
时间: 2023-12-17 16:01:15 浏览: 17
import torch
class SelfAttention(torch.nn.Module):
    """Minimal single-head scaled dot-product self-attention.

    Projects the input into query/key/value spaces with three linear layers,
    computes ``softmax(Q K^T / sqrt(hidden_size)) V``, and returns both the
    attended output and the attention weights.

    Args:
        input_size: feature dimension of the input, last axis of ``inputs``.
        hidden_size: projection dimension for Q, K and V.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.query = torch.nn.Linear(input_size, hidden_size)
        self.key = torch.nn.Linear(input_size, hidden_size)
        self.value = torch.nn.Linear(input_size, hidden_size)
        # BUG FIX: the softmax must normalize over the *key* axis — the last
        # dim of the (batch, seq_q, seq_k) score matrix — so each query's
        # weights sum to 1. The original dim=1 normalized over the query axis
        # (columns), which is not valid attention; it only looked plausible
        # because the self-attention score matrix is square.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs):
        """Apply self-attention.

        Args:
            inputs: tensor of shape ``(batch, seq_len, input_size)``
                (``K.transpose(1, 2)`` below requires exactly 3 dims).

        Returns:
            Tuple of ``(output, attn_weights)`` with shapes
            ``(batch, seq_len, hidden_size)`` and
            ``(batch, seq_len, seq_len)`` respectively.
        """
        Q = self.query(inputs)
        K = self.key(inputs)
        V = self.value(inputs)
        # Scale by sqrt(d_k) to keep dot-product magnitudes in a range where
        # the softmax gradients stay usable (Vaswani et al., 2017).
        scores = torch.matmul(Q, K.transpose(1, 2)) / self.hidden_size ** 0.5
        attn_weights = self.softmax(scores)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)