How do I write a self-attention mechanism in PyTorch?
Sure, here is a simple implementation of a self-attention mechanism:
import torch
import torch.nn.functional as F
class SelfAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.query = torch.nn.Linear(hidden_size, hidden_size)
        self.key = torch.nn.Linear(hidden_size, hidden_size)
        self.value = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
    def forward(self, x, mask=None):
        batch_size, seq_len, hidden_size = x.size()
        # Project the input to queries, keys, and values first, then split the
        # hidden dimension into num_heads: (batch, heads, seq_len, head_size)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        # Compute scaled dot-product attention scores: (batch, heads, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_size ** 0.5
        # Apply mask (optional), e.g. a causal or padding mask; True entries are blocked
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Compute the weighted sum of values and merge the heads back into hidden_size
        context = torch.matmul(attn_weights, value)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, seq_len, hidden_size)
        return context
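For reference, here is a minimal usage sketch. The batch size, sequence length, and hidden size below are illustrative assumptions, not values from the original answer; the causal mask shows how the optional mask argument can be used:
# Hypothetical shapes: batch of 2, sequence length 5, hidden size 64, 8 heads
x = torch.randn(2, 5, 64)
attn = SelfAttention(hidden_size=64, num_heads=8)
# Plain self-attention, no mask
out = attn(x)                                   # shape: (2, 5, 64)
# Optional causal mask: position i may not attend to positions > i
causal_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
out_causal = attn(x, mask=causal_mask)          # the (5, 5) mask broadcasts over batch and heads
print(out.shape, out_causal.shape)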
Note that this is only a simplified implementation; a production version has to handle further details, such as padding masks for variable-length sequences. A sketch of one way to do that follows below.
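As a rough sketch of how padding could be handled with the mask argument above (the sequence lengths here are made up for illustration):
# Hypothetical example: two sequences padded to length 5, real lengths 5 and 3
lengths = torch.tensor([5, 3])
seq_len = 5
# True where a position is padding and must be ignored
pad = torch.arange(seq_len)[None, :] >= lengths[:, None]   # (batch, seq_len)
# Expand so it masks the key dimension of the (batch, heads, query, key) scores
pad_mask = pad[:, None, None, :]                            # (batch, 1, 1, seq_len)
out = attn(x, mask=pad_mask)   # reuses x and attn from the usage sketch above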