Transformer-based self-attention mechanism: a code example
Below is example code for a Transformer-based self-attention mechanism:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads  # dimension of each attention head
        self.query_projection = nn.Linear(d_model, d_model)
        self.key_projection = nn.Linear(d_model, d_model)
        self.value_projection = nn.Linear(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.permute(0, 2, 1, 3)

    def scaled_dot_product_attention(self, query, key, value, mask=None):
        # Scale by sqrt(d_k), the per-head dimension
        dk = torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        scores = torch.matmul(query, key.transpose(-2, -1)) / dk
        if mask is not None:
            # Positions where mask == 0 are suppressed before the softmax
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query = self.query_projection(query)
        key = self.key_projection(key)
        value = self.value_projection(value)
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(query, key, value, mask)
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()
        scaled_attention = scaled_attention.view(batch_size, -1, self.d_model)
        output = self.output_projection(scaled_attention)
        return output, attention_weights
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.normalization1 = nn.LayerNorm(d_model)
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model)
        )
        self.dropout2 = nn.Dropout(dropout_rate)
        self.normalization2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention sub-layer with residual connection and layer normalization
        attention_output, _ = self.multi_head_attention(x, x, x, mask)
        attention_output = self.dropout1(attention_output)
        normal_output1 = self.normalization1(x + attention_output)
        # Feed-forward sub-layer with residual connection and layer normalization
        feed_forward_output = self.feed_forward(normal_output1)
        feed_forward_output = self.dropout2(feed_forward_output)
        normal_output2 = self.normalization2(normal_output1 + feed_forward_output)
        return normal_output2
class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, dropout_rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        # Register as a buffer so the encoding moves with the module across devices
        self.register_buffer('pos_encoding', self.positional_encoding(maximum_position_encoding, d_model))
        self.dropout = nn.Dropout(dropout_rate)
        self.encoder_layers = nn.ModuleList([TransformerBlock(d_model, num_heads, dff, dropout_rate)
                                             for _ in range(num_layers)])

    def positional_encoding(self, position, d_model):
        # Sinusoidal positional encoding, shape (1, position, d_model)
        pe = torch.zeros(position, d_model)
        position = torch.arange(0, position, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return pe

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x)
        x = x * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.encoder_layers[i](x, mask)
        return x
class Transformer(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, target_vocab_size, dropout_rate=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, dff, input_vocab_size,
                                          maximum_position_encoding, dropout_rate)
        self.decoder_embedding = nn.Embedding(target_vocab_size, d_model)
        self.register_buffer('pos_encoding', self.positional_encoding(maximum_position_encoding, d_model))
        self.dropout = nn.Dropout(dropout_rate)
        # Note: these decoder layers are simplified self-attention blocks; a full
        # Transformer decoder would also include cross-attention over the encoder output.
        self.decoder_layers = nn.ModuleList([TransformerBlock(d_model, num_heads, dff, dropout_rate)
                                             for _ in range(num_layers)])
        self.output_projection = nn.Linear(d_model, target_vocab_size)

    def positional_encoding(self, position, d_model):
        pe = torch.zeros(position, d_model)
        position = torch.arange(0, position, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return pe

    def forward(self, x, y, mask=None):
        # encoder_output is not consumed by the simplified decoder blocks below
        encoder_output = self.encoder(x, mask)
        seq_len = y.size(1)
        y = self.decoder_embedding(y)
        y = y * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        y = y + self.pos_encoding[:, :seq_len, :]
        y = self.dropout(y)
        for i in range(self.num_layers):
            y = self.decoder_layers[i](y, mask)
        output = self.output_projection(y)
        return output
```
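For reference, the operation implemented inside `scaled_dot_product_attention` is the standard scaled dot-product attention, where $Q$, $K$, and $V$ are the per-head query, key, and value matrices and $d_k$ is the per-head dimension:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$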
This example defines a MultiHeadAttention module and a TransformerBlock module, which implement the self-attention mechanism. The TransformerEncoder and Transformer modules assemble them into a complete Transformer encoder and a (simplified, self-attention-only) decoder. The code is implemented in PyTorch and can be used as a starting point for training and inference tasks.
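As a quick smoke test, the sketch below instantiates the model and runs a forward pass on random token IDs. All hyperparameter values (layer count, dimensions, vocabulary sizes, sequence lengths) are arbitrary and chosen only for illustration:

```python
import torch

# Minimal usage sketch; all hyperparameters below are arbitrary illustration values.
model = Transformer(num_layers=2, d_model=128, num_heads=8, dff=512,
                    input_vocab_size=1000, maximum_position_encoding=200,
                    target_vocab_size=1000, dropout_rate=0.1)

src = torch.randint(0, 1000, (4, 20))  # (batch_size, source_len) integer token ids
tgt = torch.randint(0, 1000, (4, 16))  # (batch_size, target_len) integer token ids

logits = model(src, tgt)               # shape (4, 16, 1000): one score per target-vocab token
print(logits.shape)
```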