if epoch % args.print_epoch == 0: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') train_stats = evaluate(data_loader_train, model, device) print(f"Accuracy of the network on the {len(dataset_train)} train images: {train_stats['acc1']:.1f}%") max_accuracy = max(max_accuracy, train_stats["acc1"]) print(f'train Max accuracy: {max_accuracy:.2f}%')
时间: 2024-03-22 09:36:59 浏览: 119
可以在每次训练过程中加入以下代码来记录并输出模型在训练集上的准确率,并更新最高训练集准确率:
```
if epoch % args.print_epoch == 0:
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
max_accuracy = max(max_accuracy, test_stats["acc1"])
print(f'Max accuracy: {max_accuracy:.2f}%')
train_stats = evaluate(data_loader_train, model, device)
print(f"Accuracy of the network on the {len(dataset_train)} train images: {train_stats['acc1']:.1f}%")
max_accuracy = max(max_accuracy, train_stats["acc1"])
print(f'train Max accuracy: {max_accuracy:.2f}%')
```
其中,`train_stats`记录了当前训练过程中模型在训练集上的准确率,每次训练完成后,判断当前的训练集准确率是否超过了最高训练集准确率,如果超过了,则更新最高训练集准确率并输出。
阅读全文
相关推荐



















