print('lr: ', optimizer.param_groups[0]['lr']) save_path = 'snapshots/{}/'.format(opt.train_save) os.makedirs(save_path, exist_ok=True) if (epoch+1) % 1 == 0: meanloss = test(model, opt.test_path) if meanloss < best_loss: print('new best loss: ', meanloss) best_loss = meanloss torch.save(model.state_dict(), save_path + 'TransFuse-%d.pth' % epoch) print('[Saving Snapshot:]', save_path + 'TransFuse-%d.pth'% epoch) return best_loss
时间: 2024-04-29 20:19:41 浏览: 164
PyTorch中model.zero_grad()和optimizer.zero_grad()用法
这段代码主要是保存模型和更新最佳损失。首先输出当前学习率,然后创建一个保存模型的路径,并在需要时创建保存模型的目录。接着,如果当前训练轮数是1的倍数,就调用test()函数进行测试,并计算测试集上的平均损失meanloss。如果meanloss小于当前的最佳损失best_loss,则更新best_loss的值,并保存模型的参数到文件中。最后,返回最佳损失的值。
阅读全文