if opt.model == 'GANet11': disp1, disp2 = model(input1, input2) disp0 = (disp1 + disp2)/2. if opt.kitti or opt.kitti2015: loss = 0.4 * F.smooth_l1_loss(disp1[mask], target[mask], reduction='mean') + 1.2 * criterion(disp2[mask], target[mask]) else: loss = 0.4 * F.smooth_l1_loss(disp1[mask], target[mask], reduction='mean') + 1.2 * F.smooth_l1_loss(disp2[mask], target[mask], reduction='mean') elif opt.model == 'GANet_deep': disp0, disp1, disp2 = model(input1, input2) if opt.kitti or opt.kitti2015: loss = 0.2 * F.smooth_l1_loss(disp0[mask], target[mask], reduction='mean') + 0.6 * F.smooth_l1_loss(disp1[mask], target[mask], reduction='mean') + criterion(disp2[mask], target[mask]) else: loss = 0.2 * F.smooth_l1_loss(disp0[mask], target[mask], reduction='mean') + 0.6 * F.smooth_l1_loss(disp1[mask], target[mask], reduction='mean') + F.smooth_l1_loss(disp2[mask], target[mask], reduction='mean') else: raise Exception("No suitable model found ...")
时间: 2024-01-28 11:03:04 浏览: 110
pdp.opt.model:使用Julia + JuMP解决PDP的模型
这段代码是一个深度学习模型的训练过程,模型有两种选择:GANet11和GANet_deep。如果选择GANet11模型,输入input1和input2将会通过模型得到两个视差图disp1和disp2,然后将两个视差图取平均值得到disp0。如果是在KITTI或KITTI2015数据集上训练,损失函数将会是0.4倍的平滑L1损失加上1.2倍的L1损失;否则损失函数将会是0.4倍的平滑L1损失加上0.6倍的平滑L1损失加上1倍的平滑L1损失。如果选择GANet_deep模型,输入input1和input2将会通过模型得到三个视差图disp0、disp1和disp2。如果是在KITTI或KITTI2015数据集上训练,损失函数将会是0.2倍的平滑L1损失加上0.6倍的平滑L1损失加上L1损失;否则损失函数将会是0.2倍的平滑L1损失加上0.6倍的平滑L1损失加上0.2倍的平滑L1损失。如果选择的不是GANet11或GANet_deep模型,则会抛出一个异常。
阅读全文