if (epoch + 1) % val_interval == 0: model.eval() with torch.no_grad(): y_pred = torch.tensor([], dtype=torch.float32, device=device) y = torch.tensor([], dtype=torch.long, device=device) for val_data in val_loader: val_images, val_labels = ( val_data[0].to(device), val_data[1].to(device), ) y_pred = torch.cat([y_pred, model(val_images)], dim=0) y = torch.cat([y, val_labels], dim=0) y_onehot = [y_trans(i) for i in decollate_batch(y, detach=False)] y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)] auc_metric(y_pred_act, y_onehot) result = auc_metric.aggregate() auc_metric.reset() del y_pred_act, y_onehot metric_values.append(result) acc_value = torch.eq(y_pred.argmax(dim=1), y) acc_metric = acc_value.sum().item() / len(acc_value) if result > best_metric: best_metric = result best_metric_epoch = epoch + 1 torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth")) print("saved new best metric model") print( f"current epoch: {epoch + 1} current AUC: {result:.4f}" f" current accuracy: {acc_metric:.4f}" f" best AUC: {best_metric:.4f}" f" at epoch: {best_metric_epoch}" )
时间: 2024-02-14 21:31:54 浏览: 146
这段代码是在训练过程中的一个epoch结束后进行的验证步骤。在每个val_interval的倍数的epoch结束后,模型会进入评估模式(model.eval()),然后用验证集(val_loader)进行验证。
首先,创建了两个空的tensor,y_pred和y,用于存储预测结果和真实标签。
然后,对于每个val_data,将验证图像(val_images)和标签(val_labels)移动到设备上。
接下来,使用模型(model)对验证图像进行预测,并使用torch.cat将每个batch的预测结果和真实标签拼接在一起,以便计算整个验证集上的指标。
然后,对于y_onehot和y_pred_act,分别对其进行转换操作,具体实现可能在其他地方。
然后,使用auc_metric计算AUC指标,并将结果添加到metric_values列表中。
接着,计算准确率指标(acc_metric),首先使用argmax函数找到每个预测结果的最大值所在的索引,然后使用torch.eq函数将预测结果与真实标签进行比较,得到一个布尔值的tensor,最后将预测正确的个数求和并除以总样本数来计算准确率。
如果当前AUC指标比之前的最佳指标(best_metric)要好,则更新best_metric和best_metric_epoch,并保存模型参数到best_metric_model.pth文件中。
最后,打印当前epoch的信息,包括当前epoch的AUC指标、准确率、最佳AUC指标及其所在的epoch。
相关问题
with torch.no_grad(): for val_data in validate_loader: val_images, val_labels = val_data outputs = net(val_images.to(device)) # eval model only have last output layer loss = loss_function(outputs, val_labels.to(device)) val_loss += loss.item() predict_y = torch.max(outputs, dim=1)[1] acc += (predict_y == val_labels.to(device)).sum().item() val_accurate = acc / val_num val_loss = val_loss /val_num if val_accurate > best_acc: best_acc = val_accurate torch.save(net.state_dict(), save_path) print('[epoch %d] train_loss: %.3f test_loss: %.3f test_accuracy: %.3f' % (epoch + 1, running_loss / step, val_loss, val_accurate)) print('Finished Training')
这段代码用于在验证集上评估模型的表现,并保存表现最好的模型。
首先,我们使用with torch.no_grad()语句块来关闭梯度计算,以便在验证过程中不产生梯度。然后,我们遍历验证集中的所有mini-batch,对每个mini-batch进行预测,并计算预测结果和标签数据之间的损失函数值。
接着,我们累加当前mini-batch中的损失函数值到val_loss中,并统计当前mini-batch中正确分类的样本数,并将其累加到acc中,以便后续计算平均准确率。
在验证过程中,我们使用torch.max函数找到模型输出中概率最大的类别,并将其作为预测结果。然后,我们将预测结果和标签数据进行比较,统计正确分类的样本数。
在每个epoch结束后,我们计算当前模型在验证集上的平均准确率val_accurate和平均损失val_loss,并将其输出到屏幕上。如果当前模型在验证集上的表现优于之前的最佳表现,则将当前模型保存到指定的路径save_path中。
最后,我们输出"Finished Training"表示训练过程结束。
scheduler.step()放在 if (epoch+1) % val_interval==0之前还是之后
通常情况下,scheduler.step()应该放在 if (epoch+1) % val_interval==0之后。因为scheduler.step()的作用是更新优化器的学习率,而在训练过程中,每个epoch结束后进行验证集的验证是很常见的操作,所以在验证集验证之后再更新学习率更为合理。如果在之前更新学习率,那么当前epoch的学习率将会与下一个epoch的验证集结果产生关联,从而影响模型的训练效果。
阅读全文