Code for the Individual Transformer Modules in PyTorch
In PyTorch, the core components of a Transformer model are the Encoder, the Decoder, and the Self-Attention mechanism. Their implementations are shown below:
1. Encoder:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)  # learned positional embedding (max length 1000)
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout)
                                     for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.scale = hid_dim ** 0.5  # scale token embeddings by sqrt(hid_dim)

    def forward(self, src, src_mask):
        # src: [batch_size, src_len], src_mask: [batch_size, 1, 1, src_len]
        batch_size, src_len = src.shape
        pos = torch.arange(0, src_len, device=src.device).unsqueeze(0).repeat(batch_size, 1)
        # scaled token embedding + positional embedding, then dropout
        src = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(pos))
        for layer in self.layers:
            src = layer(src, src_mask)
        return src  # [batch_size, src_len, hid_dim]
```
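The Encoder above calls an EncoderLayer class that is not shown in the original snippet. Below is a minimal sketch of one possible implementation, assuming the SelfAttention class from section 3 and a simple position-wise feed-forward sublayer; the PositionwiseFeedforward class and the post-layer-norm residual arrangement are assumptions, not part of the original:
```python
class PositionwiseFeedforward(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # [batch, seq_len, hid_dim] -> [batch, seq_len, pf_dim] -> [batch, seq_len, hid_dim]
        return self.fc_2(self.dropout(F.relu(self.fc_1(x))))


class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        self.self_attn = SelfAttention(hid_dim, n_heads, dropout)
        self.ff = PositionwiseFeedforward(hid_dim, pf_dim, dropout)
        self.attn_norm = nn.LayerNorm(hid_dim)
        self.ff_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        # self-attention sublayer with residual connection + layer norm
        _src, _ = self.self_attn(src, src, src, src_mask)
        src = self.attn_norm(src + self.dropout(_src))
        # position-wise feed-forward sublayer with residual connection + layer norm
        _ff = self.ff(src)
        src = self.ff_norm(src + self.dropout(_ff))
        return src
```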
2. Decoder:
```python
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)  # learned positional embedding (max length 1000)
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout)
                                     for _ in range(n_layers)])
        self.fc_out = nn.Linear(hid_dim, output_dim)  # project back to vocabulary logits
        self.dropout = nn.Dropout(dropout)
        self.scale = hid_dim ** 0.5

    def forward(self, trg, enc_src, trg_mask, src_mask):
        # trg: [batch_size, trg_len], enc_src: [batch_size, src_len, hid_dim]
        batch_size, trg_len = trg.shape
        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
        trg = self.dropout(self.tok_embedding(trg) * self.scale + self.pos_embedding(pos))
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg)  # [batch_size, trg_len, output_dim]
        return output, attention
```
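Like the encoder, the Decoder relies on a DecoderLayer class that the original snippet omits. The following is a minimal sketch of one way to write it, reusing the SelfAttention class from section 3 and the PositionwiseFeedforward sublayer from the encoder sketch above; the exact sublayer ordering and normalization scheme are assumptions:
```python
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()
        self.self_attn = SelfAttention(hid_dim, n_heads, dropout)
        self.enc_attn = SelfAttention(hid_dim, n_heads, dropout)
        self.ff = PositionwiseFeedforward(hid_dim, pf_dim, dropout)
        self.self_attn_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_norm = nn.LayerNorm(hid_dim)
        self.ff_norm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        # masked self-attention over the target sequence
        _trg, _ = self.self_attn(trg, trg, trg, trg_mask)
        trg = self.self_attn_norm(trg + self.dropout(_trg))
        # encoder-decoder attention: queries from the decoder, keys/values from the encoder
        _trg, attention = self.enc_attn(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_norm(trg + self.dropout(_trg))
        # position-wise feed-forward
        _ff = self.ff(trg)
        trg = self.ff_norm(trg + self.dropout(_ff))
        # return the updated target plus the cross-attention weights,
        # matching the "trg, attention = layer(...)" call in Decoder.forward
        return trg, attention
```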
3. Self-Attention:
```python
class SelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** 0.5  # plain float, so no device handling is needed

    def forward(self, query, key, value, mask=None):
        # query/key/value: [batch_size, seq_len, hid_dim]
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # split into heads: [batch_size, n_heads, seq_len, head_dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # scaled dot-product scores: [batch_size, n_heads, query_len, key_len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = F.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), V)
        # merge the heads back: [batch_size, query_len, hid_dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention
```
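A quick, illustrative shape check of the SelfAttention module (the batch size, sequence length, and dimensions below are arbitrary, not from the original):
```python
attn = SelfAttention(hid_dim=64, n_heads=8, dropout=0.1)
x = torch.randn(2, 5, 64)              # [batch_size, seq_len, hid_dim]
out, weights = attn(x, x, x)           # self-attention: query = key = value
print(out.shape)      # torch.Size([2, 5, 64])
print(weights.shape)  # torch.Size([2, 8, 5, 5]) -> [batch, heads, query_len, key_len]
```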
These snippets cover the core building blocks of the Transformer model. Note that they are only the basic module implementations; in practice you still need to initialize the parameters, build the source and target attention masks, and assemble the modules into a full model, as sketched below.
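One possible way to tie the pieces together is a small Seq2Seq wrapper that builds the padding and causal masks and chains the encoder and decoder. This is a minimal sketch: the Seq2Seq class, the pad indices, the mask shapes, and the hyperparameters below are assumptions for illustration, not part of the original:
```python
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        # [batch_size, 1, 1, src_len]: hide padding positions
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        trg_len = trg.shape[1]
        # lower-triangular causal mask so position i only attends to positions <= i
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
        return trg_pad_mask & trg_sub_mask  # [batch_size, 1, trg_len, trg_len]

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention


# Example assembly (vocabulary sizes and hyperparameters are illustrative)
enc = Encoder(input_dim=5000, hid_dim=256, n_layers=3, n_heads=8, pf_dim=512, dropout=0.1)
dec = Decoder(output_dim=5000, hid_dim=256, n_layers=3, n_heads=8, pf_dim=512, dropout=0.1)
model = Seq2Seq(enc, dec, src_pad_idx=0, trg_pad_idx=0)
```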