def eval_psnr(loader, model, eval_type=None): model.eval() if eval_type == 'f1': metric_fn = utils.calc_f1 metric1, metric2, metric3, metric4 = 'f1', 'auc', 'none', 'none' elif eval_type == 'building': metric_fn = utils.calc_fmeasure metric1, metric2, metric3, metric4 = 'build', 'non_build', 'none', 'none' elif eval_type == 'ber': metric_fn = utils.calc_ber metric1, metric2, metric3, metric4 = 'shadow', 'non_shadow', 'ber', 'none' elif eval_type == 'cod': metric_fn = utils.calc_cod metric1, metric2, metric3, metric4 = 'sm', 'em', 'wfm', 'mae' if local_rank == 0: pbar = tqdm(total=len(loader), leave=False, desc='val') else: pbar = None pred_list = [] gt_list = [] for batch in loader: for k, v in batch.items(): batch[k] = v.cuda() inp = batch['inp'] pred = torch.sigmoid(model.infer(inp)) batch_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())] batch_gt = [torch.zeros_like(batch['gt']) for _ in range(dist.get_world_size())] dist.all_gather(batch_pred, pred) pred_list.extend(batch_pred) dist.all_gather(batch_gt, batch['gt']) gt_list.extend(batch_gt) if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() pred_list = torch.cat(pred_list, 1) gt_list = torch.cat(gt_list, 1) result1, result2, result3, result4 = metric_fn(pred_list, gt_list) return result1, result2, result3, result4, metric1, metric2, metric3, metric4
时间: 2024-03-29 13:37:38 浏览: 77
pytorch:model.train和model.eval用法及区别详解
这是一个用于评估模型性能的函数,其输入参数包括一个数据集加载器(loader)、一个模型(model)和一个评估类型(eval_type)。函数根据评估类型选择不同的指标(metric_fn)来评估模型的性能,并返回四个评估结果(result1, result2, result3, result4)和四个指标(metric1, metric2, metric3, metric4)。函数的实现过程中,使用了分布式训练和异步数据加载的技术,以提高计算效率。
阅读全文