Attention in PyTorch
PyTorch lets you implement many different types of attention mechanisms; which one to use depends on your specific task and model architecture. Below are example implementations of several common ones:
1. Dot-Product Attention:
```python
import torch
import torch.nn as nn

class DotProductAttention(nn.Module):
    def __init__(self):
        super(DotProductAttention, self).__init__()

    def forward(self, query, key, value):
        # Similarity scores between each query and each key: (..., q_len, k_len)
        scores = torch.matmul(query, key.transpose(-2, -1))
        # Normalize the scores into attention weights over the keys
        attention_weights = torch.softmax(scores, dim=-1)
        # Weighted sum of the values
        context_vector = torch.matmul(attention_weights, value)
        return context_vector, attention_weights
```
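A minimal usage sketch; the batch size, sequence lengths, and feature dimension below are arbitrary values chosen for illustration:
```python
attn = DotProductAttention()
query = torch.randn(2, 5, 64)   # (batch, query_len, dim)
key = torch.randn(2, 7, 64)     # (batch, key_len, dim)
value = torch.randn(2, 7, 64)   # (batch, key_len, dim)
context, weights = attn(query, key, value)
print(context.shape)   # torch.Size([2, 5, 64])
print(weights.shape)   # torch.Size([2, 5, 7])
```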
2. Scaled Dot-Product Attention:
```python
import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self, scale_factor):
        super(ScaledDotProductAttention, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, query, key, value):
        # Divide the scores by the scale factor to keep the softmax well-behaved
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale_factor
        attention_weights = torch.softmax(scores, dim=-1)
        context_vector = torch.matmul(attention_weights, value)
        return context_vector, attention_weights
```
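The scale factor is usually the square root of the key dimension d_k; a small sketch assuming d_k = 64:
```python
import math

d_k = 64
attn = ScaledDotProductAttention(scale_factor=math.sqrt(d_k))
query = torch.randn(2, 5, d_k)
key = torch.randn(2, 7, d_k)
value = torch.randn(2, 7, d_k)
context, weights = attn(query, key, value)
print(context.shape, weights.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 7])
```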
3. Multi-Head Attention:
```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Linear projections for queries, keys, and values
        self.query_linear = nn.Linear(input_dim, hidden_dim)
        self.key_linear = nn.Linear(input_dim, hidden_dim)
        self.value_linear = nn.Linear(input_dim, hidden_dim)
        # Final projection back to the input dimension
        self.output_linear = nn.Linear(hidden_dim, input_dim)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        head_dim = self.hidden_dim // self.num_heads
        # Project and split into heads: (batch, num_heads, seq_len, head_dim)
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        # Scaled dot-product attention within each head
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
        attention_weights = torch.softmax(scores, dim=-1)
        # Concatenate the heads and project back to the input dimension
        context_vector = torch.matmul(attention_weights, value).transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        output = self.output_linear(context_vector)
        return output, attention_weights
```
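A quick self-attention sketch in which the same tensor is passed as query, key, and value; input_dim=512, hidden_dim=512, and num_heads=8 are illustrative choices (hidden_dim must be divisible by num_heads):
```python
mha = MultiHeadAttention(input_dim=512, hidden_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)      # (batch, seq_len, input_dim)
output, weights = mha(x, x, x)   # self-attention: query = key = value
print(output.shape)    # torch.Size([2, 10, 512])
print(weights.shape)   # torch.Size([2, 8, 10, 10]), one attention map per head
```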
These examples show the basic structure of each attention mechanism. Note that the exact implementation will depend on your model architecture and task requirements. Hope this helps!