if train_cfg.flownet == '2sd': flow_net = FlowNet2SD(batchNorm=False, div_flow=20, channel_n=channel_n) flow_net.load_state_dict(torch.load('models/flownet2/FlowNet2-SD.pth')['state_dict']) else: flow_net = lite_flow.Network() flow_net.load_state_dict(torch.load('models/liteFlownet/network-default.pytorch')) flow_net.cuda().eval() # U
时间: 2024-04-04 08:35:07 浏览: 150
ctce_cfg.rar
这段代码是用来加载光流模型的。首先判断使用的光流模型类型,如果是'2sd',则加载FlowNet2SD模型;否则,加载liteFlowNet模型。
在加载模型之后,将其移动到GPU上,并设置为评估模式(eval)。这样,在进行光流计算时,就可以直接使用该模型进行计算。
值得注意的是,这段代码中使用了torch.load()函数来加载模型参数,该函数默认加载CPU上的模型参数。因此,在加载模型之后,还需要调用.cuda()函数将其移动到GPU上。
阅读全文