def forward(self, source: List[List[str]], target: List[List[str]]) -> torch.Tensor: # Compute sentence lengths source_lengths = [len(s) for s in source] # Convert list of lists into tensors source_padded = self.vocab.src
时间: 2024-04-28 09:23:36 浏览: 189
以下是 forward 函数的代码:
```
def forward(self, source: List[List[str]], target: List[List[str]]) -> torch.Tensor:
# Compute sentence lengths
source_lengths = [len(s) for s in source]
# Convert list of lists into tensors
source_padded = self.vocab.src.to_input_tensor(source, device=self.device) # shape: (src_len, batch_size)
target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device) # shape: (tgt_len, batch_size)
# Embed the source and target sequences
source_embedded = self.model_embeddings_source(source_padded) # shape: (src_len, batch_size, embed_size)
target_embedded = self.model_embeddings_target(target_padded) # shape: (tgt_len, batch_size, embed_size)
# Encode the source sequence
enc_hiddens, dec_init_state = self.encoder(source_embedded, source_lengths) # enc_hiddens shape: (src_len, batch_size, hidden_size)
# Decode the target sequence
combined_outputs = self.decoder(target_embedded, dec_init_state, enc_hiddens) # shape: (tgt_len, batch_size, hidden_size)
# Compute scores
P = F.log_softmax(self.target_vocab_projection(combined_outputs), dim=-1) # shape: (tgt_len, batch_size, tgt_vocab_size)
return P
```
该函数的目标是将源语言和目标语言的句子分别转换为嵌入表示,并使用编码器-解码器模型对目标语言进行解码,最后计算得分。
函数的输入参数包括源语言句子列表 source 和目标语言句子列表 target。函数返回一个 torch.Tensor 类型的得分矩阵 P,其形状为 (tgt_len, batch_size, tgt_vocab_size)。
在函数中,首先计算了源语言句子的长度 source_lengths。然后,使用 vocab 中的 to_input_tensor 函数将源语言和目标语言的句子列表转换为张量。其中,source_padded 张量的形状为 (src_len, batch_size),target_padded 张量的形状为 (tgt_len, batch_size)。
接下来,将源语言和目标语言的张量输入到嵌入层模型中,得到源语言和目标语言的嵌入表示 source_embedded 和 target_embedded。其中,source_embedded 的形状为 (src_len, batch_size, embed_size),target_embedded 的形状为 (tgt_len, batch_size, embed_size)。
然后,将 source_embedded 和 source_lengths 作为输入传递给编码器,得到编码器的输出 enc_hiddens 和解码器的初始状态 dec_init_state。其中,enc_hiddens 的形状为 (src_len, batch_size, hidden_size)。
最后,将 target_embedded、dec_init_state 和 enc_hiddens 作为输入传递给解码器,得到 combined_outputs。然后,将 combined_outputs 输入到目标语言词汇表投影层中,得到得分矩阵 P。最后,使用 F.log_softmax 函数对 P 进行 log-softmax 处理,将最终结果返回。
阅读全文