def evaluate(self, datloader_Test): Image_Feature_ALL = [] Image_Name = [] Sketch_Feature_ALL = [] Sketch_Name = [] start_time = time.time() self.eval() for i_batch, sampled_batch in enumerate(datloader_Test): sketch_feature, positive_feature = self.test_forward(sampled_batch) Sketch_Feature_ALL.extend(sketch_feature) #草图特征 模型的 Sketch_Name.extend(sampled_batch['sketch_path']) #草图名 for i_num, positive_name in enumerate(sampled_batch['positive_path']): #遍历正例图像 if positive_name not in Image_Name: Image_Name.append(positive_name) Image_Feature_ALL.append(positive_feature[i_num]) rank = torch.zeros(len(Sketch_Name)) Image_Feature_ALL = torch.stack(Image_Feature_ALL) Image_Feature_ALL = Image_Feature_ALL.view(Image_Feature_ALL.size(0), -1) for num, sketch_feature in enumerate(Sketch_Feature_ALL): s_name = Sketch_Name[num] sketch_query_name = os.path.basename(s_name) # 提取草图路径中的文件名作为查询名称 position_query = -1 for i, image_name in enumerate(Image_Name): if sketch_query_name in os.path.basename(image_name): # 提取图像路径中的文件名进行匹配 position_query = i break if position_query != -1: sketch_feature = sketch_feature.view(1, -1) distance = F.pairwise_distance(sketch_feature, Image_Feature_ALL) target_distance = F.pairwise_distance(sketch_feature, Image_Feature_ALL[position_query].view(1, -1)) rank[num] = distance.le(target_distance).sum() top1 = rank.le(1).sum().item() / rank.shape[0] top10 = rank.le(10).sum().item() / rank.shape[0] print('Time to Evaluate: {}'.format(time.time() - start_time)) return top1, top10
时间: 2024-02-14 14:29:53 浏览: 131
PESQ.zip_Evaluate the audio_evaluation
这是一个evaluate函数的更新版本。在这个版本中,函数接受一个datloader_Test参数,该参数是一个数据加载器,用于加载测试数据集。
首先,函数初始化一些变量,包括用于存储图像特征、图像名称、草图特征和草图名称的列表。然后,函数将模型设置为评估模式。
接下来,函数遍历测试数据集中的每个批次。对于每个批次,函数调用test_forward方法获取草图特征和正样本特征,并将它们分别添加到相应的列表中。同时,函数还将正样本的路径添加到图像名称列表中。
然后,函数遍历所有的草图特征,并根据草图路径提取查询名称。然后,函数在图像名称列表中查找与查询名称匹配的索引,并将其赋值给position_query变量。
接下来,如果position_query不等于-1,则说明找到了与查询名称匹配的正样本图像。函数使用F.pairwise_distance计算草图特征与所有正样本特征之间的距离,并使用F.pairwise_distance计算草图特征与对应正样本特征之间的距离。
然后,函数计算每个草图与所有正样本之间的排序值,并统计排名在前1和前10的比例。
最后,函数打印评估时间,并返回top1和top10的比例。
请注意,这只是代码的一个简单解释,具体实现可能还涉及其他细节。
阅读全文