我想用pytorch实现机器翻译的任务,用transformer
时间: 2023-09-16 10:09:11 浏览: 50
可以用 Transformer 模型实现。以下是一个简单的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
# Source/target language Fields (legacy torchtext API, removed in
# torchtext >= 0.12). tokenize="spacy" requires the spaCy "de" and "en"
# models to be installed locally.
SRC = Field(tokenize="spacy", tokenizer_language="de", init_token="<sos>", eos_token="<eos>", lower=True)
TRG = Field(tokenize="spacy", tokenizer_language="en", init_token="<sos>", eos_token="<eos>", lower=True)
# Download/load the Multi30k German->English dataset (network side effect).
train_data, valid_data, test_data = Multi30k.splits(exts=(".de", ".en"), fields=(SRC, TRG))
# Build vocabularies from the training split only; tokens seen fewer than
# 2 times fall back to <unk>.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)
# 定义模型
class Transformer(nn.Module):
    """Seq2seq Transformer for translation.

    Inputs are token-index tensors of shape (seq_len, batch), as produced by
    torchtext's (non-batch-first) iterators; `forward` returns unnormalized
    logits of shape (trg_len, batch, trg_vocab_size).

    Fixes over the naive version:
      * sinusoidal positional encodings are added to the embeddings —
        without them self-attention has no notion of token order;
      * `generate_square_subsequent_mask` is provided (the inference code
        calls it, but `nn.TransformerEncoder`/`Decoder` wrappers do not
        inherit it from `nn.Transformer`).
    """

    def __init__(self, src_vocab_size, trg_vocab_size, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout),
            num_layers=num_encoder_layers)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout),
            num_layers=num_decoder_layers)
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.trg_embed = nn.Embedding(trg_vocab_size, d_model)
        self.generator = nn.Linear(d_model, trg_vocab_size)
        self.d_model = d_model
        self.emb_dropout = nn.Dropout(dropout)
        # Precomputed sinusoidal positional encodings, shape (max_len, 1, d_model),
        # registered as a buffer so .to(device) moves it with the model.
        # NOTE: assumes d_model is even (standard for nhead-divisible sizes).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pos_enc", pe.unsqueeze(1))

    @staticmethod
    def generate_square_subsequent_mask(sz):
        """Additive causal mask: 0 on/below the diagonal, -inf above, so
        position i may only attend to positions <= i."""
        return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

    def _embed(self, tokens, embedding):
        # Scale embeddings by sqrt(d_model) (Vaswani et al.), add positional
        # encodings for the actual sequence length, then apply dropout.
        x = embedding(tokens) * (self.d_model ** 0.5)
        return self.emb_dropout(x + self.pos_enc[: x.size(0)])

    def forward(self, src, trg, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Encode `src`, decode `trg` against the memory, project to vocab logits."""
        memory = self.encoder(self._embed(src, self.src_embed), mask=src_mask,
                              src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(self._embed(trg, self.trg_embed), memory,
                              tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(output)
# Hyperparameters
BATCH_SIZE = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
dropout = 0.1
lr = 0.0005
epochs = 10
# Model, optimizer, loss; padding positions are excluded from the loss.
model = Transformer(src_vocab_size, trg_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi["<pad>"])
# BucketIterator batches sentences of similar length; tensors come out as
# (seq_len, batch) since Field defaults to batch_first=False.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)
# Training loop: teacher forcing — feed trg[:-1], predict trg[1:].
for epoch in range(epochs):
    model.train()  # inference code below switches to eval(); switch back each epoch
    for i, batch in enumerate(train_iterator):
        src = batch.src  # (src_len, batch)
        trg = batch.trg  # (trg_len, batch)
        optimizer.zero_grad()
        # BUGFIX: a causal mask on the decoder input is mandatory. Without
        # it each position attends to future target tokens and the task
        # degenerates to copying the shifted target.
        tgt_len = trg.size(0) - 1
        tgt_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), device=trg.device),
            diagonal=1)
        output = model(src, trg[:-1], tgt_mask=tgt_mask)
        # Flatten (tgt_len, batch, vocab) -> (tgt_len*batch, vocab) for CE.
        loss = criterion(output.reshape(-1, trg_vocab_size), trg[1:].reshape(-1))
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f"Epoch: {epoch+1}/{epochs}, Batch: {i+1}/{len(train_iterator)}, Loss: {loss.item()}")
# Inference
def translate_sentence(sentence, src_field, trg_field, model, device, max_len=50):
    """Greedy-decode one sentence (a string or a pre-tokenized list) into a
    list of target-language tokens (without <sos>/<eos>).

    Fixes over the naive version:
      * decoding goes through the model's public ``forward`` so embeddings
        (and whatever else the model applies) are used — the original passed
        raw index tensors straight into ``model.encoder``;
      * no causal mask on the source side (bidirectional attention over the
        source is correct); the causal mask applies to the target prefix;
      * the decoder is fed the FULL generated prefix each step, not just the
        last token;
      * ``Field.tokenize`` already yields plain strings, so no ``.text``;
      * a trailing real token is no longer dropped when max_len is reached
        without emitting <eos>.
    """
    model.eval()
    if isinstance(sentence, str):
        tokens = [token.lower() for token in src_field.tokenize(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)  # (src_len, 1)
    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(1).to(device)
            tgt_len = trg_tensor.size(0)
            tgt_mask = torch.triu(
                torch.full((tgt_len, tgt_len), float("-inf"), device=device),
                diagonal=1)
            # Full forward pass each step (re-encodes the source; simple and
            # correct, at the cost of some redundant compute).
            logits = model(src_tensor, trg_tensor, tgt_mask=tgt_mask)
            pred_token = logits.argmax(2)[-1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_idx:
                break
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    # Strip <sos>; strip <eos> only if decoding actually emitted one.
    if trg_tokens[-1] == trg_field.eos_token:
        trg_tokens = trg_tokens[:-1]
    return trg_tokens[1:]
# Example: translate one German sentence and print the tokens.
sentence = "Eine Gruppe von Menschen steht vor einem Iglu ."
translation = translate_sentence(sentence, SRC, TRG, model, device)
print(" ".join(translation))
```
需要注意的是,这个示例代码只是一个基本的实现(并且使用了 torchtext 0.12 之前的旧版 `Field`/`BucketIterator` API,新版 torchtext 已将其移除),实际上需要进行更多的调整和优化(例如位置编码、学习率调度、束搜索解码等)才能得到更好的结果。