class MSMDAERNet(nn.Module): def init(self, pretrained=False, number_of_source=15, number_of_category=4): super(MSMDAERNet, self).init() self.sharedNet = pretrained_CFE(pretrained=pretrained) # for i in range(1, number_of_source): # exec('self.DSFE' + str(i) + '=DSFE()') # exec('self.cls_fc_DSC' + str(i) + '=nn.Linear(32,' + str(number_of_category) + ')') for i in range(number_of_source): exec('self.DSFE' + str(i) + '=DSFE()') exec('self.cls_fc_DSC' + str(i) + '=nn.Linear(32,' + str(number_of_category) + ')') def forward(self, data_src, number_of_source, data_tgt=0, label_src=0, mark=0): ''' description: take one source data and the target data in every forward operation. the mmd loss is calculated between the source data and the target data (both after the DSFE) the discrepency loss is calculated between all the classifiers' results (test on the target data) the cls loss is calculated between the ground truth label and the prediction of the mark-th classifier 之所以target data每一条线都要过一遍是因为要计算discrepency loss, mmd和cls都只要mark-th那条线就行 param {type}: mark: int, the order of the current source data_src: take one source data each time number_of_source: int label_Src: corresponding label data_tgt: target data return {type} ''' mmd_loss = 0 disc_loss = 0 data_tgt_DSFE = [] if self.training == True: # common feature extractor data_src_CFE = self.sharedNet(data_src) data_tgt_CFE = self.sharedNet(data_tgt) # Each domian specific feature extractor # to extract the domain specific feature of target data for i in range(number_of_source): DSFE_name = 'self.DSFE' + str(i) data_tgt_DSFE_i = eval(DSFE_name)(data_tgt_CFE) data_tgt_DSFE.append(data_tgt_DSFE_i) # Use the specific feature extractor # to extract the source data, and calculate the mmd loss DSFE_name = 'self.DSFE' + str(mark) data_src_DSFE = eval(DSFE_name)(data_src_CFE) # mmd_loss += utils.mmd(data_src_DSFE, data_tgt_DSFE[mark]) mmd_loss += utils.mmd_linear(data_src_DSFE, data_tgt_DSFE[mark]) # discrepency loss for i in range(len(data_tgt_DSFE)): if i != mark: disc_loss += torch.mean(torch.abs( F.softmax(data_tgt_DSFE[mark], dim=1) - F.softmax(data_tgt_DSFE[i], dim=1) )) # domain specific classifier and cls_loss DSC_name = 'self.cls_fc_DSC' + str(mark) pred_src = eval(DSC_name)(data_src_DSFE) cls_loss = F.nll_loss(F.log_softmax( pred_src, dim=1), label_src.squeeze()) return cls_loss, mmd_loss, disc_loss中data_tgt_DSFE的长度
时间: 2024-03-07 10:53:38 浏览: 147
data_tgt_DSFE 的长度是 number_of_source,也就是目标数据经过每个领域特定特征提取器后得到的特征向量的个数。在 forward 函数中,for 循环会遍历所有的领域特定特征提取器,将目标数据经过每个领域特定特征提取器后得到的特征向量存储在 data_tgt_DSFE 列表中,因此 data_tgt_DSFE 的长度就等于 number_of_source。
阅读全文