Implementing a Transformer in PyTorch
Sure! Below is example code for a Transformer model implemented with PyTorch:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.position_encoding = PositionalEncoding(hidden_dim)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(hidden_dim, num_heads) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # x: (batch_size, seq_len) token indices
        # Embedding and positional encoding
        x = self.embedding(x)
        x = self.position_encoding(x)
        # Transformer encoder stack
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
        # Output projection to vocabulary log-probabilities
        x = self.fc(x)
        x = F.log_softmax(x, dim=-1)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(EncoderLayer, self).__init__()
        self.multihead_attention = MultiheadAttention(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim)

    def forward(self, x):
        # Multi-head self-attention
        x = self.multihead_attention(x)
        # Position-wise feed-forward network
        x = self.ffn(x)
        return x


class MultiheadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.linear_q = nn.Linear(hidden_dim, hidden_dim)
        self.linear_k = nn.Linear(hidden_dim, hidden_dim)
        self.linear_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        batch_size = x.size(0)
        # Project the input to queries, keys and values
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        # Split into heads: (batch_size, num_heads, seq_len, head_dim)
        q = self._split_heads(q, batch_size)
        k = self._split_heads(k, batch_size)
        v = self._split_heads(v, batch_size)
        scaled_attention = self._scaled_dot_product_attention(q, k, v)
        # Concatenate heads back to (batch_size, seq_len, hidden_dim)
        scaled_attention = self._concat_heads(scaled_attention, batch_size)
        x = self.fc(scaled_attention)
        return x

    def _split_heads(self, x, batch_size):
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def _concat_heads(self, x, batch_size):
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)

    def _scaled_dot_product_attention(self, q, k, v):
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        dk = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / dk
        attention_weights = F.softmax(scores, dim=-1)
        scaled_attention = torch.matmul(attention_weights, v)
        return scaled_attention


class FeedForwardNetwork(nn.Module):
    def __init__(self, hidden_dim):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, hidden_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.hidden_dim = hidden_dim
        # Precompute the sinusoidal position encodings
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, hidden_dim) for batch-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_len, hidden_dim)
        x = x * math.sqrt(self.hidden_dim)
        x = x + self.pe[:, :x.size(1), :]
        return x
```
This is a simplified Transformer model consisting of an encoder stack, multi-head attention, a feed-forward network, and positional encoding. For brevity it omits residual connections, layer normalization, dropout, attention masking, and the decoder; you can modify and extend it as needed.
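As a quick sanity check, you can run a forward pass like the minimal sketch below. The hyperparameters (vocabulary size 1000, hidden size 128, 2 layers, 4 heads) and batch shape are made up for illustration; pick values that fit your task.

```python
# Hypothetical hyperparameters, chosen only to exercise the model.
model = Transformer(input_dim=1000, hidden_dim=128, num_layers=2, num_heads=4)

# A random batch of token indices: (batch_size=8, seq_len=32)
tokens = torch.randint(0, 1000, (8, 32))

log_probs = model(tokens)
print(log_probs.shape)  # torch.Size([8, 32, 1000]) -- per-token log-probabilities
```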
I hope this code helps! If you have any other questions, feel free to ask.