A Transformer Code Example Based on PyTorch
Implementing a Transformer with PyTorch
The Transformer is a very popular deep learning model, widely used in natural language processing, speech recognition, and other fields. It can be implemented directly with PyTorch's building blocks; below is a simple code example:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # dimension of each attention head
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, query, key, value):
        # (batch, heads, seq_q, depth) x (batch, heads, depth, seq_k) -> (batch, heads, seq_q, seq_k)
        matmul_qk = torch.matmul(query, key.transpose(-2, -1))
        dk = torch.tensor(self.depth, dtype=torch.float32)
        scaled_attention_logits = matmul_qk / torch.sqrt(dk)
        attention_weights = F.softmax(scaled_attention_logits, dim=-1)
        output = torch.matmul(attention_weights, value)
        return output

    def split_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, heads, seq, depth)
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(1, 2)

    def forward(self, query, key, value):
        batch_size = query.shape[0]
        query = self.wq(query)
        key = self.wk(key)
        value = self.wv(value)
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)
        scaled_attention = self.scaled_dot_product_attention(query, key, value)
        # (batch, heads, seq, depth) -> (batch, seq, heads, depth) -> (batch, seq, d_model)
        scaled_attention = scaled_attention.transpose(1, 2)
        concat_attention = scaled_attention.reshape(batch_size, -1, self.d_model)
        output = self.fc(concat_attention)
        return output


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, x):
        # Self-attention sub-layer with residual connection and layer norm
        attn_output = self.mha(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(x + attn_output)
        # Position-wise feed-forward sub-layer with residual connection and layer norm
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        out2 = self.layernorm2(out1 + ffn_output)
        return out2


class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]


class Transformer(nn.Module):
    def __init__(self,
                 input_vocab_size,
                 target_vocab_size,
                 max_len_input,
                 max_len_target,
                 num_layers=4,
                 d_model=128,
                 num_heads=8,
                 dff=512,
                 rate=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.encoder_embedding = nn.Embedding(input_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(target_vocab_size, d_model)
        self.pos_encoding_input = PositionalEncoding(max_len_input, d_model)
        self.pos_encoding_target = PositionalEncoding(max_len_target, d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, dff, rate) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, dff, rate) for _ in range(num_layers)])
        self.final_layer = nn.Linear(d_model, target_vocab_size)

    def forward(self,
                input_seq,
                target_seq,
                input_mask=None,
                target_mask=None):
        # NOTE: input_mask / target_mask are accepted but not used in this simplified version.
        # Embed the tokens, scale by sqrt(d_model), and add positional encodings
        input_seq_embd = self.encoder_embedding(input_seq) * math.sqrt(self.d_model)
        input_seq_embd = self.pos_encoding_input(input_seq_embd)
        target_seq_embd = self.decoder_embedding(target_seq) * math.sqrt(self.d_model)
        target_seq_embd = self.pos_encoding_target(target_seq_embd)

        enc_output = input_seq_embd
        for i in range(self.num_layers):
            enc_output = self.encoder_layers[i](enc_output)

        # Simplified decoder: self-attention blocks only, no cross-attention over enc_output
        dec_output = target_seq_embd
        for i in range(self.num_layers):
            dec_output = self.decoder_layers[i](dec_output)

        final_output = self.final_layer(dec_output)
        return final_output
```
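To check that the modules wire together, here is a minimal smoke test; the vocabulary sizes, sequence lengths, and batch size are arbitrary values chosen only for illustration:
```python
# Hypothetical sizes, chosen only to exercise the model
model = Transformer(input_vocab_size=1000,
                    target_vocab_size=1200,
                    max_len_input=50,
                    max_len_target=60)

batch_size, src_len, tgt_len = 8, 40, 35
input_seq = torch.randint(0, 1000, (batch_size, src_len))    # (batch, src_len) token ids
target_seq = torch.randint(0, 1200, (batch_size, tgt_len))   # (batch, tgt_len) token ids

logits = model(input_seq, target_seq)
print(logits.shape)  # expected: torch.Size([8, 35, 1200])
```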
This example includes the Multi-Head Attention, Transformer Block, Positional Encoding, and Transformer modules that make up a Transformer model. You can adjust the hyperparameters and the model structure to suit your own application; one common extension is adding attention masks, sketched below.
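Note that forward accepts input_mask and target_mask but the simplified blocks above ignore them, so the decoder can attend to future target positions. If you extend scaled_dot_product_attention to accept a mask, a causal (look-ahead) mask could be built as in this sketch; the helper name is hypothetical and not part of the code above:
```python
import torch

def make_causal_mask(seq_len):
    # True above the diagonal marks future positions that should not be attended to
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Hypothetical use inside scaled_dot_product_attention, before the softmax:
#   scaled_attention_logits = scaled_attention_logits.masked_fill(mask, float('-inf'))
```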