Transformer code
Below is a simple implementation of a Transformer model, including the Encoder and Decoder parts:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # Linear projections, split into heads: (bs, seq_len, n_heads, d_k)
        k = self.k_linear(k).view(bs, -1, self.n_heads, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.n_heads, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.n_heads, self.d_k)
        # Transpose to get dimensions (bs, n_heads, seq_len, d_k)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the head dimension
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        attention = torch.matmul(scores, v)
        # Concatenate heads and apply the output projection
        concat_attention = attention.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out_linear(concat_attention)
        return output


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = F.relu(self.linear_1(x))
        x = self.dropout(x)
        x = self.linear_2(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model)
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Multi-head self-attention
        attn_output = self.multi_head_attention(x, x, x, mask=mask)
        attn_output = self.dropout_1(attn_output)
        # Residual connection and layer normalization
        out1 = self.layer_norm_1(x + attn_output)
        # Feed-forward layer
        ff_output = self.feed_forward(out1)
        ff_output = self.dropout_2(ff_output)
        # Residual connection and layer normalization
        out2 = self.layer_norm_2(out1 + ff_output)
        return out2


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.multi_head_attention_1 = MultiHeadAttention(d_model, n_heads)
        self.multi_head_attention_2 = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model)
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.layer_norm_3 = nn.LayerNorm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Masked multi-head self-attention
        attn_output_1 = self.multi_head_attention_1(x, x, x, mask=tgt_mask)
        attn_output_1 = self.dropout_1(attn_output_1)
        # Residual connection and layer normalization
        out1 = self.layer_norm_1(x + attn_output_1)
        # Multi-head attention over the encoder output
        attn_output_2 = self.multi_head_attention_2(out1, enc_output, enc_output, mask=src_mask)
        attn_output_2 = self.dropout_2(attn_output_2)
        # Residual connection and layer normalization
        out2 = self.layer_norm_2(out1 + attn_output_2)
        # Feed-forward layer
        ff_output = self.feed_forward(out2)
        ff_output = self.dropout_3(ff_output)
        # Residual connection and layer normalization
        out3 = self.layer_norm_3(out2 + ff_output)
        return out3


class Encoder(nn.Module):
    def __init__(self, input_dim, d_model, n_layers, n_heads, dropout=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, d_model)
        self.pos_embedding = nn.Embedding(1000, d_model)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, dropout) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Embedding and position encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        pos = torch.arange(0, x.size(1), device=x.device).unsqueeze(0).repeat(x.size(0), 1)
        pos = self.pos_embedding(pos)
        x = x + pos
        x = self.dropout(x)
        # Encoder layers
        for layer in self.layers:
            x = layer(x, mask)
        return x


class Decoder(nn.Module):
    def __init__(self, output_dim, d_model, n_layers, n_heads, dropout=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, d_model)
        self.pos_embedding = nn.Embedding(1000, d_model)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, dropout) for _ in range(n_layers)])
        self.out_linear = nn.Linear(d_model, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Embedding and position encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        pos = torch.arange(0, x.size(1), device=x.device).unsqueeze(0).repeat(x.size(0), 1)
        pos = self.pos_embedding(pos)
        x = x + pos
        x = self.dropout(x)
        # Decoder layers
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        # Linear projection to output vocabulary
        output = self.out_linear(x)
        return output


class Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, d_model, n_layers, n_heads, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(input_dim, d_model, n_layers, n_heads, dropout)
        self.decoder = Decoder(output_dim, d_model, n_layers, n_heads, dropout)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_output = self.encoder(src, src_mask)
        output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        return output
```
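As a minimal usage sketch (not part of the original listing), the model above can be instantiated and run on dummy token batches as follows. The vocabulary sizes, hyperparameters, and the assumption that index 0 is the padding token are illustrative; the mask shapes are chosen so they broadcast against the (batch, n_heads, query_len, key_len) attention scores after the `unsqueeze(1)` inside `MultiHeadAttention`:
```python
import torch

# Illustrative hyperparameters (assumptions, not from the original post)
model = Transformer(input_dim=1000, output_dim=1000, d_model=512, n_layers=6, n_heads=8)

src = torch.randint(1, 1000, (2, 10))  # (batch, src_len) token ids
tgt = torch.randint(1, 1000, (2, 12))  # (batch, tgt_len) token ids

# Source padding mask: (batch, 1, src_len); 1 = keep, 0 = masked out (assumes pad id 0)
src_mask = (src != 0).unsqueeze(1)

# Target mask: padding mask combined with a causal (lower-triangular) mask,
# so each target position only attends to itself and earlier positions
tgt_pad_mask = (tgt != 0).unsqueeze(1)                      # (batch, 1, tgt_len)
causal = torch.tril(torch.ones(12, 12, dtype=torch.bool))   # (tgt_len, tgt_len)
tgt_mask = tgt_pad_mask & causal.unsqueeze(0)               # (batch, tgt_len, tgt_len)

logits = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
print(logits.shape)  # torch.Size([2, 12, 1000])
```
For training, one would typically feed the target sequence shifted right as `tgt` and compute a cross-entropy loss between `logits` and the un-shifted target tokens.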