pytorch的self-attention代码
时间: 2024-08-29 21:00:59 浏览: 85
自注意力(Self-Attention)通常用于 Transformer 模型中。它是一种计算机制,让每个输入元素(如文本中的词)都能直接与序列中的所有其他元素交互,从而提取全局上下文信息。下面是一个用 PyTorch 实现的简单多头自注意力层示例,在 Transformer 模块中,这一部分通常称为 `MultiHeadAttention`:
```python
import torch
from torch import nn
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The feature dimension ``d_model`` is split evenly into ``num_heads`` heads
    of size ``head_dim = d_model // num_heads``. Queries, keys and values are
    projected by separate linear layers, attention is computed independently
    per head, the heads are concatenated back to ``d_model`` features, and the
    result goes through an output projection followed by dropout.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the projected output.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.fc_q = nn.Linear(d_model, d_model)  # query projection
        self.fc_k = nn.Linear(d_model, d_model)  # key projection
        self.fc_v = nn.Linear(d_model, d_model)  # value projection
        self.fc_out = nn.Linear(d_model, d_model)  # projection applied after concatenating heads
        self.dropout = nn.Dropout(dropout)

    def attention(self, query, key, value, mask=None):
        """Compute scaled dot-product attention over the last two dimensions.

        Args:
            query: Tensor whose last two dims are (q_len, head_dim).
            key: Tensor whose last two dims are (k_len, head_dim).
            value: Tensor whose last two dims are (k_len, head_dim).
            mask: Optional tensor broadcastable to the score shape
                (..., q_len, k_len); positions where ``mask == 0`` receive
                zero attention weight.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has last two dims
            (q_len, head_dim). ``attn_weights`` has last two dims
            (q_len, k_len), and each row sums to 1.
        """
        # Divide by sqrt(head_dim) so dot products don't grow with head size and saturate softmax.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value instead of a fixed -1e9,
            # which is not representable in float16 and makes masked_fill fail.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: (batch_size, q_len, d_model) tensor.
            key: (batch_size, k_len, d_model) tensor.
            value: (batch_size, k_len, d_model) tensor.
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, q_len, k_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            ``(attn_output, attn_weights)`` with shapes
            (batch_size, q_len, d_model) and
            (batch_size, num_heads, q_len, k_len).
        """
        batch_size = query.size(0)
        q = self.fc_q(query)  # (batch_size, q_len, d_model)
        k = self.fc_k(key)    # (batch_size, k_len, d_model)
        v = self.fc_v(value)  # (batch_size, k_len, d_model)
        # Split d_model into heads: (batch_size, num_heads, seq_len, head_dim).
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        attn_output, attn_weights = self.attention(q, k, v, mask)
        # Merge heads back: (batch_size, num_heads, q_len, head_dim)
        # -> (batch_size, q_len, num_heads, head_dim) -> (batch_size, q_len, d_model).
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )
        attn_output = self.fc_out(attn_output)  # final output projection
        attn_output = self.dropout(attn_output)
        return attn_output, attn_weights
```
阅读全文