# Build a tensor from the ID sequence, cast it to 64-bit integers with .long(),
# and move it to `device` so it can be fed to the model.
input_token = torch.tensor(input_id).long().to(device)
时间: 2024-06-04 10:08:17 浏览: 22
这行代码的作用是将输入的 ID 序列转换为 PyTorch 的 Tensor,并将其放到指定的设备上(比如 GPU),以便输入到模型中进行推理。其中,`input_id` 是输入的 ID 序列,`device` 是目标设备。`torch.tensor()` 把数据转换为 PyTorch 的 Tensor,`.long()` 把数据类型转换为 64 位整型(`torch.int64`,嵌入层等索引操作要求该类型),`.to(device)` 把 Tensor 移动到指定设备上。
相关问题
def decode(decoder, decoder_hidden, encoder_outputs, voc, max_length=MAX_LENGTH):
    """Greedily decode a token sequence with ``decoder``.

    Starting from ``SOS_token``, the decoder is called once per step and the
    highest-scoring index is fed back as the next input. Decoding stops after
    ``max_length`` steps or as soon as ``EOS_token`` is produced.

    Args:
        decoder: Callable invoked as
            ``decoder(decoder_input, decoder_hidden, encoder_outputs)`` and
            returning ``(decoder_output, decoder_hidden)``.
        decoder_hidden: Initial hidden state passed to the first decoder call.
        encoder_outputs: Encoder outputs forwarded unchanged to every call.
        voc: Vocabulary object exposing ``index2word``.
        max_length: Maximum number of decoding steps.

    Returns:
        ``(decoded_words, decoder_attentions)``: the decoded words (ending with
        ``'<EOS>'`` when the end token was produced) and one row of
        ``decoder_attentions`` per executed step.
    """
    decoder_input = torch.LongTensor([[SOS_token]]).to(device)
    decoded_words = []
    # NOTE(review): the decoder call below does not return attention weights,
    # so this buffer is never written and the returned slice is all zeros.
    # TODO: or (MAX_LEN+1, MAX_LEN+1)
    decoder_attentions = torch.zeros(max_length, max_length)

    # Counted explicitly so the final slice is well-defined even when
    # max_length == 0 (the loop variable would be unbound in that case).
    steps = 0
    for _ in range(max_length):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        steps += 1

        # Greedy choice: only the single best index is used, so ask for top-1
        # (also works for vocabularies with fewer than three entries).
        _, topi = decoder_output.topk(1)
        ni = topi[0][0].item()

        if ni == EOS_token:
            decoded_words.append('<EOS>')
            break

        decoded_words.append(voc.index2word[ni])
        decoder_input = torch.LongTensor([[ni]]).to(device)

    return decoded_words, decoder_attentions[:steps]
这段代码实现了一个解码函数 `decode`,用于在序列到序列模型中生成输出序列。下面是对该函数的解释:
- `decode` 函数接受以下参数:
- `decoder`:解码器模型
- `decoder_hidden`:解码器的初始隐藏状态
- `encoder_outputs`:编码器的输出
- `voc`:词汇表对象,用于将索引转换为单词
- `max_length`:生成序列的最大长度,默认为预定义的最大长度 `MAX_LENGTH`
- 首先,创建一个张量 `decoder_input`,其中只包含起始标记 SOS_token。
- 将 `decoder_input` 移动到适当的设备上。
- 创建空列表 `decoded_words` 用于存储解码后的单词,并创建形状为 `(max_length, max_length)` 的全零张量 `decoder_attentions` 用于存放注意力权重。
- 使用一个循环来生成序列。在每个时间步中,执行以下操作:
- 使用解码器模型、隐藏状态和编码器的输出计算解码器的输出和下一个隐藏状态。
- 从解码器的输出中选择前三个最高值作为候选词的索引。
- 选择候选词中的第一个作为当前时间步的输出。
- 如果当前时间步的输出为 EOS_token,表示已经生成了结束标记,将 '<EOS>' 添加到 `decoded_words` 中并终止循环。
- 否则,将当前时间步的输出对应的单词添加到 `decoded_words` 中,并把该输出的索引作为下一个时间步的 `decoder_input`。
- 循环结束后,返回 `decoded_words` 和 `decoder_attentions` 的前 `di + 1` 行(即实际执行的步数)。注意:代码中解码器的调用并未返回注意力权重,`decoder_attentions` 从未被写入,因此返回的始终是全零张量。
这段代码使用了一个循环来逐步生成输出序列,并且在每个时间步只生成一个单词。生成的序列可以通过 `decoded_words` 获取,注意力权重可以通过 `decoder_attentions` 获取。
希望以上解释对您有帮助!如果您还有其他问题,请随时提问。
改进transformer基本翻译模型代码
以下是改进后的Transformer基本翻译模型代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class Transformer(nn.Module):
    """Translation model built from a stack of ``TransformerBlock`` layers.

    Source and target tokens are embedded (token + learned position embedding)
    and the target representation is passed through every block, each of which
    attends to the embedded source. A final linear layer maps to target logits.

    NOTE(review): there is no separate encoder stack, and source and target
    ids share ``tok_embedding`` (sized by ``input_dim``), so every target id
    must also be smaller than ``input_dim``.

    Args:
        input_dim: Number of rows in the shared token embedding table.
        output_dim: Size of the output (target) vocabulary.
        hid_dim: Embedding / hidden width.
        n_layers: Number of ``TransformerBlock`` layers.
        n_heads: Attention heads per block.
        pf_dim: Inner width of the position-wise feed-forward layer.
        dropout: Dropout probability.
        device: Device on which the scale constant is initially created.
        max_length: Number of positions in the learned position table;
            sequences longer than this cannot be embedded.
    """

    def __init__(self, input_dim, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device,
                 max_length=1000):
        super().__init__()
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList(
            [TransformerBlock(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        # Registered as a non-persistent buffer so it follows `module.to(...)`
        # without adding a key to the state dict (old checkpoints still load).
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([hid_dim])).to(device),
            persistent=False,
        )

    def _embed(self, tokens):
        """Return dropout(token_embedding * scale + position_embedding)."""
        seq_len = tokens.shape[1]
        # Positions are created on the same device as the input instead of
        # relying on a module-level `device` global; shape [1, seq len]
        # broadcasts over the batch dimension.
        pos = torch.arange(seq_len, device=tokens.device).unsqueeze(0)
        return self.dropout(self.tok_embedding(tokens) * self.scale + self.pos_embedding(pos))

    def forward(self, src, trg, src_mask, trg_mask):
        """Compute target logits.

        Args:
            src: Source token ids, [batch size, src len].
            trg: Target token ids, [batch size, trg len].
            src_mask: Mask broadcastable to [batch size, 1, 1, src len].
            trg_mask: Mask broadcastable to [batch size, 1, trg len, trg len].

        Returns:
            Logits of shape [batch size, trg len, output_dim].
        """
        # Target is embedded first, then source (keeps dropout RNG order).
        trg = self._embed(trg)
        src = self._embed(src)

        for layer in self.layers:
            trg, _ = layer(trg, src, trg_mask, src_mask)

        return self.fc_out(trg)
class TransformerBlock(nn.Module):
    """One decoder-style layer: self-attention, attention over the source,
    and a position-wise feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and the
    sum is layer-normalised.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, trg, src, trg_mask, src_mask):
        """Run the three sub-layers in sequence.

        Args:
            trg: [batch size, trg len, hid dim]
            src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            ``(trg, attention)``: the updated target representation and the
            attention weights produced by the source-attention sub-layer.
        """
        # 1) Target attends to itself.
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self._add_and_norm(self.self_attn_layer_norm, trg, self_attended)

        # 2) Target attends to the source representation.
        src_attended, attention = self.encoder_attention(trg, src, src, src_mask)
        trg = self._add_and_norm(self.enc_attn_layer_norm, trg, src_attended)

        # 3) Position-wise feed-forward.
        transformed = self.positionwise_feedforward(trg)
        trg = self._add_and_norm(self.ff_layer_norm, trg, transformed)

        return trg, attention
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product attention split across ``n_heads`` heads.

    Args:
        hid_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Device on which the scale constant is initially created.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Non-persistent buffer: moves with `module.to(...)` but is not saved
        # in the state dict, so existing checkpoints remain loadable.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([self.head_dim])).to(device),
            persistent=False,
        )

    def _split_heads(self, x, batch_size):
        """[batch, len, hid dim] -> [batch, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, key len, hid dim]
            mask: Optional tensor broadcastable to
                [batch size, n heads, query len, key len]; positions where it
                equals 0 are excluded from attention.

        Returns:
            ``(x, attention)`` with ``x`` of shape
            [batch size, query len, hid dim] and ``attention`` of shape
            [batch size, n heads, query len, key len].
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            # Use the dtype's most negative finite value instead of a fixed
            # -1e10, which cannot be represented in float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim = -1)

        x = torch.matmul(self.dropout(attention), V)
        # Merge the heads back: [batch, n heads, len, head dim] -> [batch, len, hid dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)

        return x, attention
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward network applied to every position:
    ``hid_dim -> pf_dim -> hid_dim`` with ReLU and dropout in between.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape [batch size, seq len, hid dim] to the same shape."""
        expanded = self.fc_1(x)
        activated = torch.relu(expanded)
        regularised = self.dropout(activated)
        return self.fc_2(regularised)
class TranslationDataset(Dataset):
    """Parallel corpus of tokenised sentences, served as vocabulary indices.

    Each item is a dict with ``"src"`` and ``"trg"`` lists of indices, every
    list wrapped in the vocabulary's ``"<sos>"`` and ``"<eos>"`` indices.
    """

    def __init__(self, src_sentences, trg_sentences, src_vocab, trg_vocab):
        self.src_sentences = src_sentences
        self.trg_sentences = trg_sentences
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab

    def __len__(self):
        """Number of sentence pairs (taken from the source side)."""
        return len(self.src_sentences)

    @staticmethod
    def _encode(words, vocab):
        """Look up ``words`` in ``vocab.stoi`` and add the boundary indices."""
        table = vocab.stoi
        return [table["<sos>"], *(table[word] for word in words), table["<eos>"]]

    def __getitem__(self, idx):
        """Return the index-encoded sentence pair at position ``idx``."""
        source_words = self.src_sentences[idx]
        target_words = self.trg_sentences[idx]
        encoded_src = self._encode(source_words, self.src_vocab)
        encoded_trg = self._encode(target_words, self.trg_vocab)
        return {"src": encoded_src, "trg": encoded_trg}
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean per-batch loss.

    The model is fed ``trg[:, :-1]`` and trained to predict ``trg[:, 1:]``.

    Args:
        model: Module called as ``model(src, trg_input, src_mask, trg_mask)``
            and returning logits of shape [batch, trg len - 1, vocab].
        iterator: Iterable of batches; each batch is a dict with ``"src"`` and
            ``"trg"`` (assumes padded LongTensors of shape [batch, len] —
            confirm against the DataLoader's collate function).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(logits [N, vocab], targets [N])``.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Sum of batch losses divided by ``len(iterator)``.
    """
    model.train()
    # Take the device from the model so batches created on the CPU are moved
    # to wherever the parameters live.
    device = next(model.parameters()).device
    src_pad_idx = SRC.vocab.stoi["<pad>"]
    trg_pad_idx = TRG.vocab.stoi["<pad>"]
    epoch_loss = 0

    for batch in iterator:
        src = batch["src"].to(device)
        trg = batch["trg"].to(device)

        # Decoder input drops the last token; the loss target drops the first.
        trg_input = trg[:, :-1]
        trg_len = trg_input.shape[1]

        # src_mask: [batch, 1, 1, src len] — hides padded source positions.
        src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2)
        # Padding mask over target keys: [batch, 1, 1, trg len].
        trg_pad_mask = (trg_input != trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # Lower-triangular mask so position i only sees positions <= i.
        causal_mask = torch.tril(torch.ones((trg_len, trg_len), device=device)).bool()
        # Broadcasts to [batch, 1, trg len, trg len].
        trg_mask = trg_pad_mask & causal_mask

        optimizer.zero_grad()
        output = model(src, trg_input, src_mask, trg_mask)

        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        target = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    """Compute the mean per-batch loss over ``iterator`` without gradients.

    The model is fed ``trg[:, :-1]`` and scored against ``trg[:, 1:]``.

    Args:
        model: Module called as ``model(src, trg_input, src_mask, trg_mask)``
            and returning logits of shape [batch, trg len - 1, vocab].
        iterator: Iterable of batches; each batch is a dict with ``"src"`` and
            ``"trg"`` (assumes padded LongTensors of shape [batch, len] —
            confirm against the DataLoader's collate function).
        criterion: Loss taking ``(logits [N, vocab], targets [N])``.

    Returns:
        Sum of batch losses divided by ``len(iterator)``.
    """
    model.eval()
    # Take the device from the model so batches created on the CPU are moved
    # to wherever the parameters live.
    device = next(model.parameters()).device
    src_pad_idx = SRC.vocab.stoi["<pad>"]
    trg_pad_idx = TRG.vocab.stoi["<pad>"]
    epoch_loss = 0

    with torch.no_grad():
        for batch in iterator:
            src = batch["src"].to(device)
            trg = batch["trg"].to(device)

            # Decoder input drops the last token; the loss target drops the first.
            trg_input = trg[:, :-1]
            trg_len = trg_input.shape[1]

            # src_mask: [batch, 1, 1, src len] — hides padded source positions.
            src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2)
            # Padding mask over target keys: [batch, 1, 1, trg len].
            trg_pad_mask = (trg_input != trg_pad_idx).unsqueeze(1).unsqueeze(2)
            # Lower-triangular mask so position i only sees positions <= i.
            causal_mask = torch.tril(torch.ones((trg_len, trg_len), device=device)).bool()
            # Broadcasts to [batch, 1, trg len, trg len].
            trg_mask = trg_pad_mask & causal_mask

            output = model(src, trg_input, src_mask, trg_mask)

            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            target = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output, target)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    """Greedily translate one sentence.

    The computation mirrors ``model``'s forward pass in eval mode (embed the
    source, embed the target prefix, run the prefix through ``model.layers``
    against the embedded source, project with ``model.fc_out``) but is done
    step by step so the last layer's source-attention weights can be returned.

    Args:
        sentence: Either a raw string (tokenised with spaCy's
            ``en_core_web_sm``) or an iterable of token strings.
        src_field: Source field exposing ``init_token``, ``eos_token`` and
            ``vocab.stoi``.
        trg_field: Target field exposing ``init_token``, ``eos_token``,
            ``vocab.stoi`` and ``vocab.itos``.
        model: Model exposing ``tok_embedding``, ``pos_embedding``, ``scale``,
            ``layers`` and ``fc_out``.
        device: Device on which the tensors are created.
        max_len: Maximum number of tokens to generate.

    Returns:
        ``(tokens, attention)``: generated target tokens (without the initial
        token, including the end token if produced) and the attention weights
        from the last layer at the final step, or ``None`` if no layer ran.
    """
    model.eval()

    if isinstance(sentence, str):
        # Imported lazily: only needed for raw-string input, and this module
        # does not import spaCy at the top level.
        import spacy
        nlp = spacy.load("en_core_web_sm")
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = (src_tensor != src_field.vocab.stoi["<pad>"]).unsqueeze(1).unsqueeze(2)

    trg_pad_idx = trg_field.vocab.stoi["<pad>"]
    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    # Stays None when max_len == 0 or the model has no layers.
    attention = None

    with torch.no_grad():
        # The source is embedded once and reused at every decoding step.
        src_pos = torch.arange(src_tensor.shape[1], device=device).unsqueeze(0)
        src_emb = model.tok_embedding(src_tensor) * model.scale + model.pos_embedding(src_pos)

        for _ in range(max_len):
            # Feed the whole prefix generated so far, not just the last token,
            # so self-attention can see the earlier outputs.
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
            trg_len = trg_tensor.shape[1]

            trg_pad_mask = (trg_tensor != trg_pad_idx).unsqueeze(1).unsqueeze(2)
            causal_mask = torch.tril(torch.ones((trg_len, trg_len), device=device)).bool()
            trg_mask = trg_pad_mask & causal_mask

            trg_pos = torch.arange(trg_len, device=device).unsqueeze(0)
            hidden = model.tok_embedding(trg_tensor) * model.scale + model.pos_embedding(trg_pos)
            for layer in model.layers:
                hidden, attention = layer(hidden, src_emb, trg_mask, src_mask)
            output = model.fc_out(hidden)

            # Greedy choice for the newest position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_idx:
                break

    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    return trg_tokens[1:], attention
# Hyperparameters
# The model embeds source and target ids with one shared token table sized by
# INPUT_DIM, so it must cover the larger of the two vocabularies.
INPUT_DIM = max(len(SRC.vocab), len(TRG.vocab))
OUTPUT_DIM = len(TRG.vocab)
HID_DIM = 256
N_LAYERS = 3
N_HEADS = 8
PF_DIM = 512
DROPOUT = 0.1
BATCH_SIZE = 128
CLIP = 1
N_EPOCHS = 10

SRC_PAD_IDX = SRC.vocab.stoi["<pad>"]
TRG_PAD_IDX = TRG.vocab.stoi["<pad>"]


def collate_batch(samples):
    """Pad the variable-length index lists of a batch into LongTensors.

    Args:
        samples: List of dicts with ``"src"`` and ``"trg"`` index lists.

    Returns:
        Dict with ``"src"`` and ``"trg"`` tensors of shape [batch, max len],
        padded with the respective ``<pad>`` index.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    src = pad_sequence(
        [torch.tensor(sample["src"], dtype=torch.long) for sample in samples],
        batch_first=True,
        padding_value=SRC_PAD_IDX,
    )
    trg = pad_sequence(
        [torch.tensor(sample["trg"], dtype=torch.long) for sample in samples],
        batch_first=True,
        padding_value=TRG_PAD_IDX,
    )
    return {"src": src, "trg": trg}


# Build the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(INPUT_DIM, OUTPUT_DIM, HID_DIM, N_LAYERS, N_HEADS, PF_DIM, DROPOUT, device).to(device)

# Optimizer and loss (padding positions do not contribute to the loss)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

# Training and validation datasets
train_dataset = TranslationDataset(train_src_sentences, train_trg_sentences, SRC.vocab, TRG.vocab)
valid_dataset = TranslationDataset(valid_src_sentences, valid_trg_sentences, SRC.vocab, TRG.vocab)

# Data loaders: sentences differ in length, so batches must be padded by
# collate_batch. Only the training data needs shuffling.
train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_iterator = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

# Train, keeping the checkpoint with the lowest validation loss
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')

# Evaluate the best checkpoint on the test set
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('tut6-model.pt', map_location=device))
test_dataset = TranslationDataset(test_src_sentences, test_trg_sentences, SRC.vocab, TRG.vocab)
test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f}')

# Translate one example
example_idx = 12
src_sentence = ' '.join(test_src_sentences[example_idx])
trg_sentence = ' '.join(test_trg_sentences[example_idx])
# Pass the existing token list so the sentence is not joined and re-tokenised.
translation, attention = translate_sentence(test_src_sentences[example_idx], SRC, TRG, model, device)
print(f'src = {src_sentence}')
print(f'trg = {trg_sentence}')
print(f'predicted trg = {translation}')
```
相关问题:
相关推荐
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)