def generate_kg_batch(self, kg_dict, batch_size, highest_neg_idx):
    """Sample one batch of knowledge-graph training triples.

    Draws ``batch_size`` head entities from the keys of ``kg_dict`` and,
    for each head, asks the sibling samplers for one positive
    (relation, tail) pair and one negative tail.

    Args:
        kg_dict: Mapping whose keys are the head entities to sample from.
            The value layout is only consumed by the helper samplers.
        batch_size: Number of triples to produce.
        highest_neg_idx: Forwarded unchanged to
            ``sample_neg_triples_for_h`` (presumably the upper bound of
            the negative-tail index range -- confirm against that helper).

    Returns:
        A 4-tuple of ``torch.LongTensor``:
        ``(batch_head, batch_relation, batch_pos_tail, batch_neg_tail)``.

    Raises:
        IndexError: If ``kg_dict`` is empty while ``batch_size`` > 0
            (raised by ``random.choice`` on an empty list).
    """
    # Materialize the keys as a list: a dict view is not a sequence, so
    # random.choice cannot index it and random.sample rejects it on
    # Python 3.11+.
    exist_heads = list(kg_dict)

    if batch_size <= len(exist_heads):
        # Enough distinct heads: sample without replacement.
        batch_head = random.sample(exist_heads, batch_size)
    else:
        # Fewer heads than requested: fall back to sampling with replacement.
        batch_head = [random.choice(exist_heads) for _ in range(batch_size)]

    batch_relation, batch_pos_tail, batch_neg_tail = [], [], []
    for h in batch_head:
        # Request exactly one relation and one positive tail per head.
        relation, pos_tail = self.sample_pos_triples_for_h(kg_dict, h, 1)
        batch_relation.extend(relation)
        batch_pos_tail.extend(pos_tail)

        # The negative tail is requested for the same head and the relation
        # that was just sampled.
        neg_tail = self.sample_neg_triples_for_h(
            kg_dict, h, relation[0], 1, highest_neg_idx
        )
        batch_neg_tail.extend(neg_tail)

    batch_head = torch.LongTensor(batch_head)
    batch_relation = torch.LongTensor(batch_relation)
    batch_pos_tail = torch.LongTensor(batch_pos_tail)
    batch_neg_tail = torch.LongTensor(batch_neg_tail)
    return batch_head, batch_relation, batch_pos_tail, batch_neg_tail
时间: 2024-04-02 13:37:29 浏览: 156
generate_sequence.rar_generate Sequence_generate_sequence_markov
这段代码用于生成知识图谱(KG)训练所需的一个 batch 数据。它首先从 kg_dict 的头实体中随机采样 batch_size 个头实体(头实体数量不足时改为有放回采样);然后针对每个头实体,随机选取一个正例三元组(即包含该头实体的一条关系和正例尾实体),并针对同一头实体和该关系采样一个负例尾实体(负例与正例共享头实体和关系,只是尾实体不同)。最终返回四个 LongTensor,分别是 batch 中的头实体、关系、正例尾实体和负例尾实体。
阅读全文