Transformer encoder code
Below is an example implementation of a Transformer encoder:
```python
import math

import torch
import torch.nn as nn


class TransformerEncoder(nn.Module):
    """Stack of Transformer encoder layers over learned token embeddings."""

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(hidden_dim, num_heads) for _ in range(num_layers)
        ])

    def forward(self, input):
        # input: (batch_size, seq_length) of token ids
        embedded_input = self.embedding(input)
        encoded_input = self.positional_encoding(embedded_input)
        for encoder_layer in self.encoder_layers:
            encoded_input = encoder_layer(encoded_input)
        return encoded_input


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to the input embeddings."""

    def __init__(self, hidden_dim, max_length=1000):
        super(PositionalEncoding, self).__init__()
        self.hidden_dim = hidden_dim
        self.max_length = max_length
        # Register as a buffer so the table follows the module across devices
        # (e.g. .to(device)) without being treated as a learnable parameter.
        self.register_buffer("positional_encoding", self.generate_positional_encoding())

    def forward(self, input):
        batch_size, seq_length, _ = input.size()
        positional_encoding = self.positional_encoding[:seq_length, :].unsqueeze(0).expand(batch_size, -1, -1)
        return input + positional_encoding

    def generate_positional_encoding(self):
        positional_encoding = torch.zeros(self.max_length, self.hidden_dim)
        position = torch.arange(0, self.max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.hidden_dim, 2).float() * (-math.log(10000.0) / self.hidden_dim))
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        return positional_encoding


class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each followed by a
    residual connection and layer normalization (post-norm)."""

    def __init__(self, hidden_dim, num_heads):
        super(TransformerEncoderLayer, self).__init__()
        self.multihead_attention = MultiheadAttention(hidden_dim, num_heads)
        self.feed_forward = FeedForward(hidden_dim)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, input):
        attention_output = self.multihead_attention(input)
        attention_output = self.layer_norm1(input + attention_output)
        feed_forward_output = self.feed_forward(attention_output)
        output = self.layer_norm2(attention_output + feed_forward_output)
        return output


class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, hidden_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.query_projection = nn.Linear(hidden_dim, hidden_dim)
        self.key_projection = nn.Linear(hidden_dim, hidden_dim)
        self.value_projection = nn.Linear(hidden_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, input):
        batch_size, seq_length, _ = input.size()
        query = self.query_projection(input)
        key = self.key_projection(input)
        value = self.value_projection(input)
        # Reshape each projection to (batch_size, num_heads, seq_length, head_dim)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        # Scaled dot-product attention
        scaled_attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attention_weights = nn.functional.softmax(scaled_attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, value)
        attention_output = self.combine_heads(attention_output)
        output = self.output_projection(attention_output)
        return output

    def split_heads(self, input):
        batch_size, seq_length, hidden_dim = input.size()
        input = input.view(batch_size, seq_length, self.num_heads, self.head_dim)
        return input.transpose(1, 2)

    def combine_heads(self, input):
        batch_size, _, seq_length, hidden_dim = input.size()
        input = input.transpose(1, 2).contiguous()
        return input.view(batch_size, seq_length, self.num_heads * self.head_dim)


class FeedForward(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion."""

    def __init__(self, hidden_dim):
        super(FeedForward, self).__init__()
        self.hidden_dim = hidden_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.ReLU(),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )

    def forward(self, input):
        return self.feed_forward(input)
```
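As a quick sanity check, here is a minimal usage sketch of the classes defined above. The vocabulary size, model dimensions, batch size, and sequence length are arbitrary illustrative values, not part of the original snippet:

```python
import torch

# Hypothetical example values: vocabulary of 10000 tokens, 512-dim model,
# 6 layers, 8 attention heads, batch of 2 sequences of length 16.
encoder = TransformerEncoder(input_dim=10000, hidden_dim=512, num_layers=6, num_heads=8)
tokens = torch.randint(0, 10000, (2, 16))  # (batch_size, seq_length) token ids
output = encoder(tokens)
print(output.shape)  # torch.Size([2, 16, 512])
```

Note that this example, like the code above, uses post-layer-norm blocks and omits padding masks and dropout; PyTorch's built-in `nn.TransformerEncoder` provides those features if they are needed.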