Implementing a Transformer Model in PyTorch
The Transformer is a widely used deep learning model for sequence data, and it performs especially well on natural language processing tasks. Below is a simple PyTorch implementation of a Transformer model:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class Transformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout):
        super(Transformer, self).__init__()
        # Source and target sequences share one embedding table
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout)
        self.encoder = Encoder(hidden_dim, num_layers, num_heads, dropout)
        self.decoder = Decoder(hidden_dim, num_layers, num_heads, dropout)
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, src, trg):
        # src, trg: (batch_size, seq_len) token indices
        src_embedded = self.embedding(src)
        trg_embedded = self.embedding(trg)
        src_encoded = self.positional_encoding(src_embedded)
        trg_encoded = self.positional_encoding(trg_embedded)
        memory = self.encoder(src_encoded)
        output = self.decoder(trg_encoded, memory)
        output = self.fc(output)
        return output
class PositionalEncoding(nn.Module):
    def __init__(self, hidden_dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the sinusoidal position table: (1, max_len, hidden_dim)
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * -(math.log(10000.0) / hidden_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # batch-first layout to match the rest of the model
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_len, hidden_dim)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, num_layers, num_heads, dropout):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.feed_forward = FeedForward(hidden_dim, dropout)
        # Each sub-layer gets its own LayerNorm
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # Self-attention sub-layer with residual connection
        residual = x
        x = self.self_attention(x, x, x)
        x = self.norm1(residual + x)
        # Feed-forward sub-layer with residual connection
        residual = x
        x = self.feed_forward(x)
        x = self.norm2(residual + x)
        return x
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        assert hidden_dim % num_heads == 0
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = math.sqrt(self.head_dim)
        self.fc_q = nn.Linear(hidden_dim, hidden_dim)
        self.fc_k = nn.Linear(hidden_dim, hidden_dim)
        self.fc_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc_o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        # query/key/value: (batch_size, seq_len, hidden_dim)
        batch_size = query.shape[0]
        q = self.fc_q(query)
        k = self.fc_k(key)
        v = self.fc_v(value)
        # Split into heads: (batch_size, num_heads, seq_len, head_dim)
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product attention
        energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-inf'))
        attention = F.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), v)
        # Merge heads back: (batch_size, seq_len, hidden_dim)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hidden_dim)
        x = self.fc_o(x)
        return x
class FeedForward(nn.Module):
    def __init__(self, hidden_dim, dropout):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
class Decoder(nn.Module):
    def __init__(self, hidden_dim, num_layers, num_heads, dropout):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, memory):
        for layer in self.layers:
            x = layer(x, memory)
        return self.norm(x)
class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.encoder_attention = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.feed_forward = FeedForward(hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, x, memory):
        # Causal mask so each target position only attends to earlier positions
        # (padding masks are omitted here for brevity)
        trg_len = x.size(1)
        causal_mask = torch.tril(torch.ones(trg_len, trg_len, device=x.device)).bool()
        # Masked self-attention sub-layer
        residual = x
        x = self.self_attention(x, x, x, mask=causal_mask)
        x = self.norm1(residual + x)
        # Encoder-decoder (cross) attention sub-layer
        residual = x
        x = self.encoder_attention(x, memory, memory)
        x = self.norm2(residual + x)
        # Feed-forward sub-layer
        residual = x
        x = self.feed_forward(x)
        x = self.norm3(residual + x)
        return x
```
This code implements a simple Transformer model, including both the encoder and the decoder. You can modify and extend it to suit your own needs (for example, by adding padding masks or a training loop).
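As a quick sanity check, the model can be instantiated and run on random token indices. The sketch below is illustrative only; the vocabulary size and hyperparameters are assumptions, not values from the original post:
```python
# Minimal usage sketch (assumed hyperparameters, not part of the original code)
vocab_size = 1000      # shared source/target vocabulary size
model = Transformer(input_dim=vocab_size, hidden_dim=256,
                    num_layers=2, num_heads=8, dropout=0.1)

src = torch.randint(0, vocab_size, (4, 20))   # (batch_size, src_len)
trg = torch.randint(0, vocab_size, (4, 15))   # (batch_size, trg_len)

logits = model(src, trg)                      # (4, 15, vocab_size)
print(logits.shape)
```
For actual training you would typically feed the decoder a right-shifted target sequence (teacher forcing) and compute a cross-entropy loss over the output logits, which this sketch leaves out.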