def predict_one_img(model, img_dataset, args): # def predict_one_img(model, img_dataset, data, target, args): dataloader = DataLoader(dataset=img_dataset, batch_size=1, num_workers=0, shuffle=False) model.eval() test_dice = DiceAverage(args.n_labels) # target = to_one_hot_3d(label, args.n_labels) with torch.no_grad(): for data ,target in tqdm(dataloader, total=len(dataloader)): # data = data.to(device) data, target = data.float(), target.long() target = to_one_hot_3d(target, args.n_labels) data, target = data.to(device), target.to(device) # print(data.shape) # print(target.shape) output = model(data) # output = nn.functional.interpolate(output, scale_factor=(1//args.slice_down_scale,1//args.xy_down_scale,1//args.xy_down_scale), mode='trilinear', align_corners=False) # 空间分辨率恢复到原始size img_dataset.update_result(output.detach().cpu()) pred = img_dataset.recompone_result() pred = torch.argmax(pred, dim=1) pred_img = to_one_hot_3d(pred, args.n_labels) pred_img=pred_img.to(device) test_dice.update(pred_img, target) test_dice = OrderedDict({'Dice_liver': test_dice.avg[1]}) if args.n_labels == 3: test_dice.update({'Dice_tumor': test_dice.avg[2]}) pred = np.asarray(pred.numpy(), dtype='uint8') if args.postprocess: pass # TO DO pred = sitk.GetImageFromArray(np.squeeze(pred, axis=0)) return test_dice, pred
时间: 2023-07-15 22:13:43 浏览: 159
这是一个用于预测单张图像的函数,接收一个模型、一个图像数据集和一些参数作为输入。函数首先将图像数据集加载到一个Dataloader中,然后将模型设置为评估模式并初始化一个DiceAverage对象,该对象用于计算评估指标(这里是Dice系数)。在没有梯度的情况下,对于每个数据和目标对,函数将数据和目标转换为模型所需的格式(浮点数和one-hot编码),并将它们传递给模型进行预测。预测结果被转换回原来的空间分辨率,然后被添加到图像数据集中。最后,预测结果被重新组合成一个numpy数组,并被转换为一个SimpleITK图像对象并返回。如果需要,可以在预测之后对预测结果进行后处理。函数返回测试指标和预测结果。
相关问题
def test(self): load_model(self.model, args.checkpoint) self.model.eval() with torch.no_grad(): rep, user_pool = self.model(self.graph) """ Save embeddings """ user_emb = (rep[:self.model.n_user] + user_pool).cpu().numpy() item_emb = rep[self.model.n_user: self.model.n_user + self.model.n_item].cpu().numpy() with open(f'HGMN-{self.args.dataset}-embeds.pkl', 'wb') as f: pickle.dump({'user_embed': user_emb, 'item_embed': item_emb}, f) """ Save results """ tqdm_dataloader = tqdm(self.testloader) uids, hrs, ndcgs = [], [], [] for iteration, batch in enumerate(tqdm_dataloader, start=1): user_idx, item_idx = batch user = rep[user_idx] + user_pool[user_idx] item = rep[self.model.n_user + item_idx] preds = self.model.predict(user, item) preds_hrs, preds_ndcgs = self.calc_hr_and_ndcg(preds, self.args.topk) hrs += preds_hrs ndcgs += preds_ndcgs uids += user_idx[::101].tolist() with open(f'HGMN-{self.args.dataset}-test.pkl', 'wb') as f: pickle.dump({uid: (hr, ndcg) for uid, hr, ndcg in zip(uids, hrs, ndcgs)}, f)
这是一个 `test` 方法的定义,用于在模型训练过程结束后对测试数据进行评估。
首先,加载模型的权重参数,使用 `load_model(self.model, args.checkpoint)` 方法将参数加载到模型中,并将模型设置为评估模式,即 `self.model.eval()`。
然后,在 `with torch.no_grad()` 上下文管理器中进行以下操作:
1. 使用模型和图数据 `self.graph` 调用模型 `self.model`,得到用户和物品的表示 `rep` 和 `user_pool`。
2. 保存嵌入向量:将用户嵌入向量和物品嵌入向量转换为 NumPy 数组,并使用 pickle 序列化保存到文件中。
3. 保存评估结果:通过遍历测试数据集中的批次,计算并保存每个用户的命中率和 NDCG 值。同时,也保存了每个用户的索引信息。最终将这些结果使用 pickle 序列化保存到文件中。
需要注意的是,在测试过程中,也没有进行模型参数的更新,因此使用了 `torch.no_grad()` 上下文管理器来禁用梯度计算,以提高效率。
这个方法的目的是对模型在测试数据集上的性能进行评估,并保存嵌入向量和评估结果供进一步分析和使用。
Traceback (most recent call last): File "/root/miniconda3/envs/test/bin/yolo", line 8, in <module> sys.exit(entrypoint()) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/cfg/__init__.py", line 394, in entrypoint getattr(model, mode)(**overrides) # default args from model File "/root/miniconda3/envs/test/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/engine/model.py", line 252, in predict return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/engine/predictor.py", line 189, in predict_cli for _ in gen: # running CLI inference without accumulating any outputs (do not modify) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 43, in generator_context response = gen.send(None) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/engine/predictor.py", line 215, in stream_inference self.setup_source(source if source is not None else self.args.source) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/engine/predictor.py", line 197, in setup_source self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/data/build.py", line 158, in load_inference_source dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride) File "/root/miniconda3/envs/test/lib/python3.8/site-packages/ultralytics/yolo/data/dataloaders/stream_loaders.py", line 57, in __init__ raise ConnectionError(f'{st}Failed to open {s}') ConnectionError: 1/1: 0... Failed to open 0 Sentry is attempting to send 2 pending events Waiting up to 2 seconds Press Ctrl-C to quit
这个错误是由于无法打开数据源导致的。具体来说,看起来代码是在尝试使用 YOLO 进行物体检测,但是无法打开数据源(source)。
你可以检查数据源路径是否正确,并确保你有足够的权限来打开该路径中的文件。另外,还可以检查你的网络连接是否正常,如果你正在尝试从远程服务器上访问数据源,那么网络连接问题可能是导致该错误的原因之一。
阅读全文