kg_data['r'] += 2 # 给每个关系的编号+2 self.n_relations = max(kg_data['r']) + 1 # 80 self.n_entities = max(max(kg_data['h']), max(kg_data['t'])) + 1 # 头实体的数量113487 # n_user在loader_base的stastic-cf里 70679 self.n_users_entities = self.n_users + self.n_entities self.cf_train_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_train_data[0]))).astype(np.int32), self.cf_train_data[1].astype(np.int32)) self.cf_test_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_test_data[0]))).astype(np.int32), self.cf_test_data[1].astype(np.int32)) self.train_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.train_user_dict.items()} self.test_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.test_user_dict.items()}
时间: 2024-03-31 20:32:51 浏览: 21
这段代码是在对知识图谱数据进行预处理,主要做了以下几件事情:
1. 将每个关系的编号加2:kg_data['r'] += 2,目的是为了给0~1之间的关系编号腾出空间,使得之后可能新增的关系有空间可用。
2. 计算实体和关系的数量:self.n_entities表示头实体的数量,self.n_relations表示关系的数量。
3. 计算用户和实体数量之和:self.n_users_entities = self.n_users + self.n_entities,用于后面构建用户与实体的交互矩阵。
4. 对训练集和测试集的数据进行处理:将每个实体的编号加上self.n_entities,目的是为了区分实体和用户,同时保证新加入的实体编号不会与原有实体或用户的编号重复。
5. 对训练集和测试集的用户字典进行处理:将字典中每个键对应的值加上self.n_entities,目的同样是为了区分实体和用户。