即插即用的 Transformer 代码
时间: 2023-10-02 11:10:59 浏览: 128
以下是一个基于 PyTorch 的 Transformer 模型的即插即用代码：
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target-vocabulary logits.

    Token id 0 is treated as padding in both ``src`` and ``tgt`` (see
    ``get_src_mask`` / ``get_tgt_mask``).

    Args:
        src_vocab_size: Number of tokens in the source vocabulary.
        tgt_vocab_size: Number of tokens in the target vocabulary; also the
            size of the last dimension of the output logits.
        max_length: Longest sequence length the positional encoding supports.
        d_model: Embedding / hidden dimension.
        num_heads: Attention heads in every attention layer.
        num_layers: Number of encoder layers and, separately, of decoder layers.
        dropout: Dropout probability passed to the encoder and decoder.
    """

    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        max_length,
        d_model=512,
        num_heads=8,
        num_layers=6,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.embedding_src = nn.Embedding(src_vocab_size, d_model)
        self.embedding_tgt = nn.Embedding(tgt_vocab_size, d_model)
        # A single positional-encoding module is shared by source and target.
        self.pos_encoding = PositionalEncoding(max_length, d_model)
        self.encoder = Encoder(d_model, num_heads, num_layers, dropout)
        self.decoder = Decoder(d_model, num_heads, num_layers, dropout)
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Run the full encoder-decoder pass.

        Args:
            src: Integer source token ids of shape (batch, src_len).
            tgt: Integer target token ids of shape (batch, tgt_len).

        Returns:
            Unnormalized logits of shape (batch, tgt_len, tgt_vocab_size).
        """
        src_mask = self.get_src_mask(src)
        tgt_mask = self.get_tgt_mask(tgt)
        src_emb = self.embedding_src(src)
        tgt_emb = self.embedding_tgt(tgt)
        src_emb = self.pos_encoding(src_emb)
        tgt_emb = self.pos_encoding(tgt_emb)
        enc_output = self.encoder(src_emb, src_mask)
        dec_output = self.decoder(tgt_emb, enc_output, tgt_mask, src_mask)
        output = self.fc(dec_output)
        return output

    def get_src_mask(self, src):
        """Return a bool key-padding mask of shape (batch, 1, 1, src_len).

        True marks real (non-zero) tokens that may be attended to; the two
        singleton dims broadcast over heads and query positions.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        return src_mask

    def get_tgt_mask(self, tgt):
        """Return the decoder self-attention mask, shape (batch, 1, tgt_len, tgt_len).

        Entry [b, 0, i, j] is True when query position i may attend to key
        position j, i.e. when j <= i (no peeking ahead) and tgt[b, j] is not
        padding.
        """
        seq_len = tgt.size(-1)
        pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, tgt_len)
        # get_subsequent_mask marks *future* positions with 1, so the allowed
        # (causal) positions are exactly where it is 0. ANDing the raw
        # subsequent mask in (as before) inverted causality.
        causal_mask = self.get_subsequent_mask(seq_len, device=tgt.device) == 0
        return pad_mask & causal_mask

    def get_subsequent_mask(self, size, device=None):
        """Return a float (size, size) matrix with 1 strictly above the diagonal.

        The 1 entries mark future key positions (j > i) that a query at
        position i must not see; callers invert it to get allowed positions.

        Args:
            size: Sequence length.
            device: Device to allocate the mask on (default: PyTorch's default).
        """
        subsequent_mask = torch.triu(torch.ones(size, size, device=device), 1)
        return subsequent_mask
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer blocks followed by a final LayerNorm.

    Args:
        d_model: Hidden dimension.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked EncoderLayer blocks.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, d_model, num_heads, num_layers, dropout):
        super().__init__()
        blocks = [EncoderLayer(d_model, num_heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        # The layers add residuals onto un-normalized activations, so the
        # stack finishes with one LayerNorm.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order (all share ``mask``), then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
class EncoderLayer(nn.Module):
    """Single pre-norm encoder block: self-attention, then a feed-forward network.

    Each sub-layer receives a LayerNorm-ed copy of the running activations;
    its output passes through dropout and is added back onto the
    un-normalized activations (residual connection).

    Args:
        d_model: Hidden dimension.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model)
        self.norm1, self.norm2 = (nn.LayerNorm(d_model) for _ in range(2))
        # nn.Dropout holds no parameters, so one instance serves both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the block to ``x``; ``mask`` is handed to self-attention unchanged."""
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attn(normed, normed, normed, mask))
        normed = self.norm2(x)
        x = x + self.dropout(self.ffn(normed))
        return x
class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer blocks followed by a final LayerNorm.

    Args:
        d_model: Hidden dimension.
        num_heads: Attention heads per attention layer.
        num_layers: Number of stacked DecoderLayer blocks.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, d_model, num_heads, num_layers, dropout):
        super().__init__()
        blocks = [DecoderLayer(d_model, num_heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        # The layers add residuals onto un-normalized activations, so the
        # stack finishes with one LayerNorm.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, tgt_mask, src_mask):
        """Feed ``x`` through every layer, each attending to the same ``enc_output``."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, enc_output, tgt_mask, src_mask)
        return self.norm(hidden)
class DecoderLayer(nn.Module):
    """Single pre-norm decoder block.

    Sub-layers, in order: masked self-attention over the target,
    attention from the target onto ``enc_output``, and a feed-forward
    network. Each one receives a LayerNorm-ed copy of the running
    activations; its output passes through dropout and is added back onto
    the un-normalized activations (residual connection).

    Args:
        d_model: Hidden dimension.
        num_heads: Number of attention heads in both attention sub-layers.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        # nn.Dropout holds no parameters, so one instance serves all sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask, src_mask):
        """Apply the block to ``x``.

        ``tgt_mask`` goes to the self-attention, ``src_mask`` to the attention
        over ``enc_output`` (which supplies its keys and values).
        """
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attn(normed, normed, normed, tgt_mask))
        normed = self.norm2(x)
        x = x + self.dropout(self.enc_attn(normed, enc_output, enc_output, src_mask))
        normed = self.norm3(x)
        x = x + self.dropout(self.ffn(normed))
        return x
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Feature size of queries, keys, values and the output.
        num_heads: Number of heads; each head works on ``d_model // num_heads``
            features.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise the per-head view() in forward either fails with an
            # opaque shape error or silently mixes positions and features.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, x, bs):
        """Reshape (bs, seq_len, d_model) into (bs, num_heads, seq_len, d_k)."""
        return x.view(bs, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Queries, (batch, q_len, d_model).
            k: Keys, (batch, kv_len, d_model).
            v: Values, (batch, kv_len, d_model).
            mask: Optional tensor broadcastable to (batch, num_heads, q_len,
                kv_len); key positions where ``mask == 0`` get zero weight.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        bs = q.size(0)
        q = self._split_heads(self.q_linear(q), bs)
        k = self._split_heads(self.k_linear(k), bs)
        v = self._split_heads(self.v_linear(v), bs)
        # Scale by sqrt of the *per-head* dimension d_k (not d_model), which
        # keeps the dot-product variance independent of the head size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value: a literal -1e9
            # overflows float16 and breaks mixed-precision runs.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(output)
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Inner (hidden) feature size.
    """

    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the inner width
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to d_model

    def forward(self, x):
        """Apply the two linear layers, with a ReLU between them, to the last dim of ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(d_model) and add fixed sinusoidal position encodings.

    Even feature columns hold sin(pos * w_i) and odd columns cos(pos * w_i),
    with w_i = 10000 ** (-2i / d_model).

    Args:
        max_length: Longest sequence length the precomputed table covers.
        d_model: Embedding dimension (odd values are supported).
    """

    def __init__(self, max_length, d_model):
        super().__init__()
        self.max_length = max_length
        self.d_model = d_model
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
        )
        pe = torch.zeros(1, max_length, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd column than even columns, so only
        # the first d_model // 2 frequencies are used for the cosine half.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer, not a Parameter: it follows .to(device) and is stored in
        # the state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x * sqrt(d_model)`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Embeddings of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length {self.max_length}"
            )
        x = x * math.sqrt(self.d_model)
        return x + self.pe[:, :seq_len, :]
```
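这段代码定义了一个 Transformer 模型，并使用 PyTorch 的 nn.Module 来组织模型结构。该模型包含编码器、解码器、多头注意力（解码器中既有自注意力，也有对编码器输出的注意力）、前馈网络和位置编码等组件。编码器层和解码器层均采用 Pre-LN 结构，即先做 LayerNorm 再进入子层，并在整个堆叠的末尾再做一次 LayerNorm。源序列和目标序列中的 token id 0 被视为填充（padding）。你可以根据具体任务修改这段代码。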
阅读全文