Can you give me the code for a Transformer?
Posted: 2023-05-20 21:03:23 · Views: 51
Sure. Below is the code for a simple Transformer model:
```
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, num_heads, dropout):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        # A single embedding table is shared by source and target (assumes a shared vocabulary).
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.pos_encoding = PositionalEncoding(hidden_dim, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, src, trg, src_mask, trg_mask):
        # Scale embeddings by sqrt(d_model), as in the original Transformer paper.
        src_embedded = self.embedding(src) * math.sqrt(self.hidden_dim)
        src_embedded = self.pos_encoding(src_embedded)
        trg_embedded = self.embedding(trg) * math.sqrt(self.hidden_dim)
        trg_embedded = self.pos_encoding(trg_embedded)
        for layer in self.encoder_layers:
            src_embedded = layer(src_embedded, src_mask)
        for layer in self.decoder_layers:
            trg_embedded = layer(trg_embedded, src_embedded, trg_mask, src_mask)
        output = self.fc_out(trg_embedded)  # [batch, trg_len, output_dim]
        return output


class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.feed_forward = FeedForward(hidden_dim, dropout)
        # Each sublayer gets its own LayerNorm so the two residual branches do not share parameters.
        self.attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.ff_layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src_embedded, src_mask):
        # Self-attention sublayer with residual connection and post-norm.
        attn_output, _ = self.self_attn(src_embedded, src_embedded, src_embedded, src_mask)
        attn_output = self.dropout(attn_output)
        src_embedded = self.attn_layer_norm(src_embedded + attn_output)
        # Position-wise feed-forward sublayer.
        ff_output = self.feed_forward(src_embedded)
        ff_output = self.dropout(ff_output)
        src_embedded = self.ff_layer_norm(src_embedded + ff_output)
        return src_embedded


class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.src_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.feed_forward = FeedForward(hidden_dim, dropout)
        # Separate LayerNorms for the three sublayers.
        self.self_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.src_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.ff_layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg_embedded, src_embedded, trg_mask, src_mask):
        # Masked self-attention over the target sequence.
        self_attn_output, _ = self.self_attn(trg_embedded, trg_embedded, trg_embedded, trg_mask)
        self_attn_output = self.dropout(self_attn_output)
        trg_embedded = self.self_attn_layer_norm(trg_embedded + self_attn_output)
        # Encoder-decoder (cross) attention over the encoder output.
        src_attn_output, _ = self.src_attn(trg_embedded, src_embedded, src_embedded, src_mask)
        src_attn_output = self.dropout(src_attn_output)
        trg_embedded = self.src_attn_layer_norm(trg_embedded + src_attn_output)
        # Position-wise feed-forward sublayer.
        ff_output = self.feed_forward(trg_embedded)
        ff_output = self.dropout(ff_output)
        trg_embedded = self.ff_layer_norm(trg_embedded + ff_output)
        return trg_embedded


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Scalar scaling factor for the dot products (avoids depending on a global `device`).
        self.scale = math.sqrt(self.head_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc_o = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # Split into heads: [batch, num_heads, seq_len, head_dim].
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product attention scores: [batch, num_heads, q_len, k_len].
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = self.dropout(F.softmax(energy, dim=-1))
        x = torch.matmul(attention, V)
        # Merge heads back: [batch, seq_len, hidden_dim].
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hidden_dim)
        x = self.fc_o(x)
        return x, attention


class FeedForward(nn.Module):
    def __init__(self, hidden_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.layer_norm(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, hidden_dim, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
```
This is a basic Transformer model, with encoder, decoder, multi-head attention, position-wise feed-forward, and positional-encoding modules.
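If it helps, here is a minimal usage sketch showing how the model above could be instantiated and called, including one way to build the padding and causal masks that `forward` expects. The vocabulary sizes, padding index, and dummy batches are illustrative assumptions, not part of the original code:
```
import torch

INPUT_DIM = 1000    # assumed source/target vocabulary size (the embedding is shared)
OUTPUT_DIM = 1000   # assumed target vocabulary size
HIDDEN_DIM = 256
NUM_LAYERS = 3
NUM_HEADS = 8
DROPOUT = 0.1
PAD_IDX = 0         # assumed padding token id

model = Transformer(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, NUM_HEADS, DROPOUT)
model.eval()  # disable dropout for a deterministic sanity check

def make_src_mask(src):
    # Padding mask: [batch, 1, 1, src_len], True where the token is not padding.
    return (src != PAD_IDX).unsqueeze(1).unsqueeze(2)

def make_trg_mask(trg):
    # Combine the padding mask with a lower-triangular (causal) mask:
    # result shape [batch, 1, trg_len, trg_len].
    pad_mask = (trg != PAD_IDX).unsqueeze(1).unsqueeze(2)
    trg_len = trg.shape[1]
    causal = torch.tril(torch.ones((trg_len, trg_len), dtype=torch.bool))
    return pad_mask & causal

src = torch.randint(1, INPUT_DIM, (2, 10))  # dummy source batch: [batch=2, src_len=10]
trg = torch.randint(1, INPUT_DIM, (2, 8))   # dummy target batch: [batch=2, trg_len=8]

logits = model(src, trg, make_src_mask(src), make_trg_mask(trg))
print(logits.shape)  # torch.Size([2, 8, 1000])
```
The masks use True (1) for positions to keep and False (0) for positions to block, which matches the `masked_fill(mask == 0, -1e10)` convention in `MultiHeadAttention`; they broadcast over the head dimension, so a single mask per batch is enough.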