h_list = [] t_list = [] r_list = [] self.train_kg_dict = collections.defaultdict(list) self.train_relation_dict = collections.defaultdict(list) for row in self.kg_train_data.iterrows(): h, r, t = row[1] h_list.append(h) t_list.append(t) r_list.append(r) self.train_kg_dict[h].append((t, r)) self.train_relation_dict[r].append((h, t)) self.h_list = torch.LongTensor(h_list) self.t_list = torch.LongTensor(t_list) self.r_list = torch.LongTensor(r_list)
时间: 2024-03-29 12:38:25 浏览: 55
train_list
这段代码看起来像是一个数据处理的部分,有些细节我可能看不出来,但是大致意思是将三元组(头实体,关系,尾实体)从kg_train_data中提取出来,并将它们分别存储在h_list、r_list、t_list三个列表中。同时,还将头实体h作为键,对应的值是一个包含(t, r)元组的列表train_kg_dict[h],表示以h为头实体的所有三元组。同理,将关系r作为键,对应的值是一个包含(h, t)元组的列表train_relation_dict[r],表示所有关系为r的三元组。最后,将h_list、t_list、r_list转化为PyTorch的张量分别存储在self.h_list、self.t_list、self.r_list中。
阅读全文