if epoch % args.print_epoch == 0: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%')
时间: 2024-03-22 22:36:58 浏览: 13
可以将上述代码与之前提到的打印训练集准确率以及记录最高训练集准确率的代码合并,形成完整的训练代码如下:
```
max_accuracy = 0.0
max_train_accuracy = 0.0
for epoch in range(args.num_epochs):
# Train the model
train(train_loader, model, criterion, optimizer, device)
# Evaluate on the training set
train_stats = evaluate(data_loader_train, model, device)
print(f"Accuracy of the network on the {len(dataset_train)} training images: {train_stats['acc1']:.1f}%")
if train_stats["acc1"] > max_train_accuracy:
max_train_accuracy = train_stats["acc1"]
print(f'Max training accuracy: {max_train_accuracy:.2f}%')
# Evaluate on the validation set
if epoch % args.print_epoch == 0:
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
max_accuracy = max(max_accuracy, test_stats["acc1"])
print(f'Max accuracy: {max_accuracy:.2f}%')
```
这样,每次训练过程中都会输出模型在训练集和验证集上的准确率,并记录最高训练集准确率和最高验证集准确率。