pytorch写一个attention
时间: 2023-09-24 17:05:40 浏览: 46
以下是一个简单的PyTorch实现的注意力模型:
```
import torch
import torch.nn as nn
class Attention(nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
self.hidden_size = hidden_size
self.attention_weights = nn.Parameter(torch.Tensor(hidden_size, 1))
nn.init.uniform_(self.attention_weights.data, -0.1, 0.1)
def forward(self, encoder_outputs):
# encoder_outputs: [batch_size, seq_len, hidden_size]
# Calculate attention weights
energy = torch.tanh(encoder_outputs)
attention_scores = torch.matmul(energy, self.attention_weights).squeeze(-1) # [batch_size, seq_len]
attention_weights = torch.softmax(attention_scores, dim=-1)
# Calculate context vector
context_vector = torch.matmul(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1) # [batch_size, hidden_size]
return context_vector, attention_weights
```
在这个模型中,注意力权重是通过一个可学习的权重矩阵来计算的。在前向传递中,我们首先将编码器的输出通过一个tanh激活函数,然后将其与注意力权重相乘,并沿着最后一个维度求和以获得注意力得分。接下来,我们将得分通过softmax函数进行归一化,以获得注意力权重。最后,我们将编码器输出和注意力权重相乘,并沿着序列维度求和以获得上下文向量。
注意力模型可以用于各种任务,例如机器翻译、问答和语言生成等。