ef train(args, model, train_loader, test_loader, boardio, textio, checkpoint): #learnable_params = filter(lambda p: p.requires_grad, model.parameters()) #optimizer = torch.optim.Adam(learnable_params) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.1) if checkpoint is not None: min_loss = checkpoint['min_loss'] optimizer.load_state_dict(checkpoint['optimizer']) best_test_loss = np.inf best_test_mse_ab = np.inf best_test_rmse_ab = np.inf best_test_mae_ab = np.inf best_test_r_mse_ab = np.inf best_test_r_rmse_ab = np.inf best_test_r_mae_ab = np.inf best_test_t_mse_ab = np.inf best_test_t_rmse_ab = np.inf best_test_t_mae_ab = np.inf for epoch in range(args.epochs): train_loss, train_mse_ab, train_mae_ab, train_rotations_ab, train_translations_ab, train_rotations_ab_pred, \ train_translations_ab_pred, train_eulers_ab, = train_one_epoch(args.device, model, train_loader, optimizer) test_loss, test_mse_ab, test_mae_ab, test_rotations_ab, test_translations_ab, test_rotations_ab_pred, \ test_translations_ab_pred, test_eulers_ab = test_one_epoch(args.device, model, test_loader)设置动态学习率
时间: 2024-04-28 07:19:06 浏览: 155
017-define-func-2-args.rar_The Test
这段代码是用来设置动态学习率的。它使用了PyTorch中的lr_scheduler模块,具体来说,使用了MultiStepLR策略。这个策略会在训练过程中根据指定的milestones(里程碑)来调整学习率,每次乘以gamma(衰减因子)。在这个代码中,milestones被设置为[50, 100, 150],gamma被设置为0.1,表示当训练到第50、100、150个epoch时,学习率会分别乘以0.1。这样做可以让模型在训练初期使用较大的学习率,快速收敛,而在训练后期使用较小的学习率,细致调整,避免过拟合。
阅读全文