import math import pandas as pd import torch from torch import nn from d2l import torch as d2l class DecoderBlock(nn.Module): """解码器中第i个块""" def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs): super(DecoderBlock, self).__init__(**kwargs) self.i = i self.attention1 = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm1 = AddNorm(norm_shape, dropout) self.attention2 = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm2 = AddNorm(norm_shape, dropout) self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens) self.addnorm3 = AddNorm(norm_shape, dropout) def forward(self, X, state): enc_outputs, enc_valid_lens = state[0], state[1] # 训练阶段,输出序列的所有词元都在同一时间处理, # 因此state[2][self.i]初始化为None。 # 预测阶段,输出序列是通过词元一个接着一个解码的, # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示 if state[2][self.i] is None: key_values = X else: key_values = torch.cat((state[2][self.i], X), axis=1) state[2][self.i] = key_values if self.training: batch_size, num_steps, _ = X.shape # dec_valid_lens的开头:(batch_size,num_steps), # 其中每一行是[1,2,...,num_steps] dec_valid_lens = torch.arange( 1, num_steps + 1, device=X.device).repeat(batch_size, 1) else: dec_valid_lens = None # 自注意力 X2 = self.attention1(X, key_values, key_values, dec_valid_lens) Y = self.addnorm1(X, X2) # 编码器-解码器注意力。 # enc_outputs的开头:(batch_size,num_steps,num_hiddens) Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens) Z = self.addnorm2(Y, Y2) return self.addnorm3(Z, self.ffn(Z)), state decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0) decoder_blk.eval() X = torch.ones((2, 100, 24)) state = [encoder_blk(X, valid_lens), valid_lens, [None]] decoder_blk(X, state)[0].shape torch.Size([2, 100, 24])
时间: 2024-04-15 11:25:19 浏览: 15
这段代码定义了一个解码器块(DecoderBlock)的类。该解码器块包含了两个多头注意力(MultiHeadAttention)、残差连接与层规范化(AddNorm)和基于位置的前馈网络(PositionWiseFFN)。
在初始化方法中,首先创建了两个多头注意力实例self.attention1和self.attention2,然后创建了三个AddNorm实例self.addnorm1、self.addnorm2和self.addnorm3,分别用于在注意力和前馈网络之后进行残差连接与层规范化。最后创建了一个PositionWiseFFN实例self.ffn。
在前向传播方法中,输入张量X和状态state被输入到两个多头注意力中进行注意力计算,并通过残差连接与层规范化进行处理。在训练阶段,解码器块的每个时间步都会处理所有词元的输出序列,因此state[2][self.i]初始化为None。在预测阶段,解码器块会逐个词元地解码输出序列,因此state[2][self.i]包含了当前时间步之前的解码器块的输出表示。
如果是训练阶段,会根据输入张量X的形状创建一个有效长度张量dec_valid_lens。如果是预测阶段,dec_valid_lens为None。然后通过自注意力计算得到中间结果X2,再通过残差连接与层规范化得到中间结果Y。接下来,使用编码器-解码器注意力计算得到中间结果Y2,再通过残差连接与层规范化得到最终的输出结果Z。
在代码的最后,创建了一个DecoderBlock的实例decoder_blk,并对其进行了评估(eval())。然后,创建了一个大小为(2, 100, 24)的张量X和状态state,并将它们输入到decoder_blk中,并打印出输出张量的形状。
结果是一个大小为(2, 100, 24)的张量,表示解码器块的输出张量的形状与输入张量相同。