改进 Transformer 基本翻译模型代码
时间: 2023-12-03 17:40:54 浏览: 136
以下是改进后的Transformer基本翻译模型代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class Transformer(nn.Module):
    """Translation model: target tokens attend to embedded source tokens.

    Source and target ids are both embedded with ``tok_embedding`` (scaled by
    ``sqrt(hid_dim)``) plus a learned positional embedding. The target stream
    is passed through ``n_layers`` ``TransformerBlock`` layers, each of which
    also attends to the embedded source, and is finally projected to
    ``output_dim`` logits.

    NOTE(review): target ids are looked up in the same table as source ids,
    which has ``input_dim`` rows, so every target id must be < ``input_dim``.
    Confirm this sharing is intended.

    Args:
        input_dim: Number of rows of the token embedding table.
        output_dim: Size of the output projection (number of logits).
        hid_dim: Embedding / hidden width.
        n_layers: Number of ``TransformerBlock`` layers.
        n_heads: Attention heads per block.
        pf_dim: Inner width of the position-wise feed-forward layer.
        dropout: Dropout probability.
        device: Device on which the ``scale`` buffer is first created.
        max_length: Number of positions in the positional embedding table;
            sequences longer than this cannot be embedded.
    """

    def __init__(self, input_dim, output_dim, hid_dim, n_layers, n_heads,
                 pf_dim, dropout, device, max_length=1000):
        super().__init__()
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # Registered as a non-persistent buffer so it follows model.to(...)
        # without adding a key to the state_dict.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([hid_dim])).to(device),
            persistent=False,
        )

    def forward(self, src, trg, src_mask, trg_mask):
        """Return logits for every target position.

        Args:
            src: [batch size, src len] source token ids.
            trg: [batch size, trg len] target token ids.
            src_mask: [batch size, 1, 1, src len]; zero entries are masked.
            trg_mask: [batch size, 1, trg len, trg len]; zero entries are masked.

        Returns:
            [batch size, trg len, output_dim] logits.
        """
        # Positions are created on the inputs' own device instead of relying
        # on a module-level ``device`` global; the [1, len] shape broadcasts
        # over the batch dimension.
        trg_pos = torch.arange(trg.shape[1], device=trg.device).unsqueeze(0)
        trg = self.dropout(
            (self.tok_embedding(trg) * self.scale) + self.pos_embedding(trg_pos)
        )

        src_pos = torch.arange(src.shape[1], device=src.device).unsqueeze(0)
        src = self.dropout(
            (self.tok_embedding(src) * self.scale) + self.pos_embedding(src_pos)
        )

        for layer in self.layers:
            # The per-layer attention weights are not returned by this method.
            trg, _ = layer(trg, src, trg_mask, src_mask)

        return self.fc_out(trg)
class TransformerBlock(nn.Module):
    """One layer: self-attention, attention over the source, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual
    connection and a ``LayerNorm`` (normalisation is applied after the
    residual sum).

    Args:
        hid_dim: Hidden width of the inputs and outputs.
        n_heads: Number of attention heads.
        pf_dim: Inner width of the position-wise feed-forward layer.
        dropout: Dropout probability.
        device: Forwarded to the attention layers.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask, src_mask):
        """Apply the block to ``trg`` while attending to ``src``.

        Args:
            trg: [batch size, trg len, hid dim]
            src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            Tuple ``(trg, attention)``: the updated target representation and
            the attention weights of the source-attention sub-layer.
        """
        # Self attention over the target sequence.
        attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(attended))

        # Attention from the target (queries) to the source (keys/values).
        attended, attention = self.encoder_attention(trg, src, src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(attended))

        # Position-wise feed-forward.
        transformed = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(transformed))

        return trg, attention
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product attention split across ``n_heads`` heads.

    Args:
        hid_dim: Width of the query/key/value inputs; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Device on which the ``scale`` buffer is first created.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Non-persistent buffer: moves with the module on .to(...) but adds
        # no key to the state_dict.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([self.head_dim])).to(device),
            persistent=False,
        )

    def _split_heads(self, x, batch_size):
        """Reshape [batch, len, hid dim] to [batch, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, key len, hid dim]
            mask: Optional tensor broadcastable to
                [batch size, n heads, query len, key len]; positions where it
                equals 0 are excluded from attention.

        Returns:
            Tuple ``(x, attention)`` with ``x`` of shape
            [batch size, query len, hid dim] and ``attention`` of shape
            [batch size, n heads, query len, key len].
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            # Use the dtype's own minimum rather than a literal -1e10, which
            # is not representable in half precision.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        x = torch.matmul(self.dropout(attention), V)
        # Merge the heads back: [batch, len, n heads, head dim] -> [batch, len, hid dim].
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)

        return x, attention
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Computes ``fc_2(dropout(relu(fc_1(x))))``.

    Args:
        hid_dim: Input and output width.
        pf_dim: Inner (expanded) width.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape [batch size, seq len, hid dim] to the same shape."""
        hidden = self.dropout(torch.relu(self.fc_1(x)))
        return self.fc_2(hidden)
class TranslationDataset(Dataset):
    """Parallel corpus that yields ``<sos> ... <eos>`` index lists.

    Args:
        src_sentences: Sequence of tokenised source sentences.
        trg_sentences: Sequence of tokenised target sentences, aligned with
            ``src_sentences``.
        src_vocab: Object whose ``stoi`` maps source tokens to indices.
        trg_vocab: Object whose ``stoi`` maps target tokens to indices.

    Raises:
        ValueError: If the two corpora have different lengths.
    """

    def __init__(self, src_sentences, trg_sentences, src_vocab, trg_vocab):
        if len(src_sentences) != len(trg_sentences):
            raise ValueError(
                "src_sentences and trg_sentences must have the same length "
                f"(got {len(src_sentences)} and {len(trg_sentences)})"
            )
        self.src_sentences = src_sentences
        self.trg_sentences = trg_sentences
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab

    def __len__(self):
        return len(self.src_sentences)

    @staticmethod
    def _encode(vocab, sentence):
        """Return ``[<sos>] + token indices + [<eos>]`` for one sentence."""
        stoi = vocab.stoi

        def lookup(token):
            try:
                return stoi[token]
            except KeyError:
                # A plain-dict vocabulary has no default entry: fall back to
                # <unk> when it exists, otherwise keep the original KeyError.
                if "<unk>" in stoi:
                    return stoi["<unk>"]
                raise

        return [stoi["<sos>"]] + [lookup(token) for token in sentence] + [stoi["<eos>"]]

    def __getitem__(self, idx):
        """Return ``{"src": [...], "trg": [...]}`` index lists for item ``idx``.

        The lists are not padded, so their lengths vary between items.
        """
        return {
            "src": self._encode(self.src_vocab, self.src_sentences[idx]),
            "trg": self._encode(self.trg_vocab, self.trg_sentences[idx]),
        }
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean batch loss.

    The padding indices are read from the module-level ``SRC`` and ``TRG``
    fields. The decoder is fed ``trg[:, :-1]`` and scored against
    ``trg[:, 1:]``.

    Args:
        model: Callable as ``model(src, trg, src_mask, trg_mask)`` returning
            [batch size, trg len - 1, output dim] logits.
        iterator: Sized iterable of batches; each batch is a mapping with
            ``"src"`` and ``"trg"`` tensors of shape [batch size, len].
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking ``(logits [N, output dim], targets [N])``.
        clip: Maximum gradient norm.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.
    """
    model.train()
    src_pad_idx = SRC.vocab.stoi["<pad>"]
    trg_pad_idx = TRG.vocab.stoi["<pad>"]
    # Batches are moved to wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    epoch_loss = 0.0

    for batch in iterator:
        src = batch["src"]
        trg = batch["trg"]
        if first_param is not None:
            src = src.to(first_param.device)
            trg = trg.to(first_param.device)

        trg_input = trg[:, :-1]
        trg_len = trg_input.shape[1]

        # src_mask: [batch size, 1, 1, src len]
        src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2)
        # Target mask = padding mask AND lower-triangular (causal) mask, so a
        # position can attend only to non-pad tokens at or before itself.
        trg_pad_mask = (trg_input != trg_pad_idx).unsqueeze(1).unsqueeze(2)
        causal_mask = torch.tril(
            torch.ones((trg_len, trg_len), device=trg_input.device)
        ).bool()
        trg_mask = trg_pad_mask & causal_mask  # [batch size, 1, trg len, trg len]

        optimizer.zero_grad()
        output = model(src, trg_input, src_mask, trg_mask)

        output_dim = output.shape[-1]
        logits = output.contiguous().view(-1, output_dim)
        targets = trg[:, 1:].contiguous().view(-1)

        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    """Compute the mean batch loss over ``iterator`` without updating weights.

    The padding indices are read from the module-level ``SRC`` and ``TRG``
    fields. The decoder is fed ``trg[:, :-1]`` and scored against
    ``trg[:, 1:]``.

    Args:
        model: Callable as ``model(src, trg, src_mask, trg_mask)`` returning
            [batch size, trg len - 1, output dim] logits.
        iterator: Sized iterable of batches; each batch is a mapping with
            ``"src"`` and ``"trg"`` tensors of shape [batch size, len].
        criterion: Loss taking ``(logits [N, output dim], targets [N])``.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.
    """
    model.eval()
    src_pad_idx = SRC.vocab.stoi["<pad>"]
    trg_pad_idx = TRG.vocab.stoi["<pad>"]
    # Batches are moved to wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    epoch_loss = 0.0

    with torch.no_grad():
        for batch in iterator:
            src = batch["src"]
            trg = batch["trg"]
            if first_param is not None:
                src = src.to(first_param.device)
                trg = trg.to(first_param.device)

            trg_input = trg[:, :-1]
            trg_len = trg_input.shape[1]

            # src_mask: [batch size, 1, 1, src len]
            src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2)
            # Target mask = padding mask AND lower-triangular (causal) mask.
            trg_pad_mask = (trg_input != trg_pad_idx).unsqueeze(1).unsqueeze(2)
            causal_mask = torch.tril(
                torch.ones((trg_len, trg_len), device=trg_input.device)
            ).bool()
            trg_mask = trg_pad_mask & causal_mask  # [batch size, 1, trg len, trg len]

            output = model(src, trg_input, src_mask, trg_mask)

            output_dim = output.shape[-1]
            logits = output.contiguous().view(-1, output_dim)
            targets = trg[:, 1:].contiguous().view(-1)

            epoch_loss += criterion(logits, targets).item()

    return epoch_loss / len(iterator)
def translate_sentence(sentence, src_field, trg_field, model, device, max_len=50):
    """Greedily decode a translation for one sentence.

    Inference mirrors ``Transformer.forward``: the source is only embedded
    (token + position), and the growing target prefix is embedded and passed
    through ``model.layers`` and ``model.fc_out``. The layers are driven
    directly so the attention weights of the last layer can be returned.

    Args:
        sentence: Either a raw string (tokenised with spaCy's
            ``en_core_web_sm``) or an iterable of string tokens. Tokens are
            lower-cased in both cases.
        src_field: Source field exposing ``init_token``, ``eos_token`` and
            ``vocab.stoi``.
        trg_field: Target field exposing ``init_token``, ``eos_token``,
            ``vocab.stoi`` and ``vocab.itos``.
        model: Model exposing ``tok_embedding``, ``pos_embedding``, ``scale``,
            ``layers`` and ``fc_out``.
        device: Device on which the input tensors are created.
        max_len: Maximum number of tokens to generate.

    Returns:
        Tuple ``(tokens, attention)``: the generated target tokens without the
        initial token, and the attention returned by the last layer at the
        final decoding step (``None`` if no step was run).
    """
    model.eval()

    if isinstance(sentence, str):
        # Imported lazily: only needed for raw-string input.
        import spacy
        nlp = spacy.load("en_core_web_sm")
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = (src_tensor != src_field.vocab.stoi["<pad>"]).unsqueeze(1).unsqueeze(2)

    trg_pad_idx = trg_field.vocab.stoi["<pad>"]
    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    attention = None

    with torch.no_grad():
        # Embed the source once; dropout is omitted because the model is in
        # eval mode, where it is the identity.
        src_pos = torch.arange(src_tensor.shape[1], device=device).unsqueeze(0)
        enc_src = model.tok_embedding(src_tensor) * model.scale + model.pos_embedding(src_pos)

        for _ in range(max_len):
            # Feed the whole prefix generated so far, not just the last token.
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
            trg_len = trg_tensor.shape[1]

            trg_pad_mask = (trg_tensor != trg_pad_idx).unsqueeze(1).unsqueeze(2)
            causal_mask = torch.tril(torch.ones((trg_len, trg_len), device=device)).bool()
            trg_mask = trg_pad_mask & causal_mask

            trg_pos = torch.arange(trg_len, device=device).unsqueeze(0)
            hidden = model.tok_embedding(trg_tensor) * model.scale + model.pos_embedding(trg_pos)
            for layer in model.layers:
                hidden, attention = layer(hidden, enc_src, trg_mask, src_mask)
            output = model.fc_out(hidden)

            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_idx:
                break

    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    return trg_tokens[1:], attention
# Hyperparameters. SRC / TRG and the *_sentences corpora are expected to be
# defined before this point.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
HID_DIM = 256
N_LAYERS = 3
N_HEADS = 8
PF_DIM = 512
DROPOUT = 0.1
BATCH_SIZE = 128
CLIP = 1
N_EPOCHS = 10
SRC_PAD_IDX = SRC.vocab.stoi["<pad>"]
TRG_PAD_IDX = TRG.vocab.stoi["<pad>"]

# Build the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(INPUT_DIM, OUTPUT_DIM, HID_DIM, N_LAYERS, N_HEADS, PF_DIM, DROPOUT, device).to(device)

# Optimizer and loss; padded target positions do not contribute to the loss.
optimizer = optim.Adam(model.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)


def collate_batch(samples):
    """Pad variable-length ``{"src", "trg"}`` index lists into batch tensors.

    TranslationDataset yields unpadded Python lists, which the default
    DataLoader collation cannot stack into [batch size, len] tensors.
    """
    from torch.nn.utils.rnn import pad_sequence

    src = pad_sequence(
        [torch.tensor(sample["src"], dtype=torch.long) for sample in samples],
        batch_first=True,
        padding_value=SRC_PAD_IDX,
    )
    trg = pad_sequence(
        [torch.tensor(sample["trg"], dtype=torch.long) for sample in samples],
        batch_first=True,
        padding_value=TRG_PAD_IDX,
    )
    return {"src": src.to(device), "trg": trg.to(device)}


# Training and validation datasets.
train_dataset = TranslationDataset(train_src_sentences, train_trg_sentences, SRC.vocab, TRG.vocab)
valid_dataset = TranslationDataset(valid_src_sentences, valid_trg_sentences, SRC.vocab, TRG.vocab)

# Data loaders. Only the training data is shuffled, which keeps the
# validation loss reproducible between runs.
train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_iterator = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

# Train, keeping the checkpoint with the lowest validation loss.
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')

# Evaluate the best checkpoint on the test set. map_location lets a
# checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('tut6-model.pt', map_location=device))
test_dataset = TranslationDataset(test_src_sentences, test_trg_sentences, SRC.vocab, TRG.vocab)
test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f}')

# Translate one example. The already-tokenised sentence is passed directly so
# it is not re-tokenised from the joined string.
example_idx = 12
src_sentence = ' '.join(test_src_sentences[example_idx])
trg_sentence = ' '.join(test_trg_sentences[example_idx])
translation, attention = translate_sentence(test_src_sentences[example_idx], SRC, TRG, model, device)
print(f'src = {src_sentence}')
print(f'trg = {trg_sentence}')
print(f'predicted trg = {translation}')
```
相关问题:
阅读全文