怎么确保在每个 epoch 结束时,将当前 epoch 的精度和损失值添加到 epoch_acc 和 epoch_loss 数组中
时间: 2024-02-11 12:04:20 浏览: 113
你可以在每个 epoch 结束时,将当前 epoch 的精度和损失值添加到 `epoch_acc` 和 `epoch_loss` 数组中。可以将以下代码添加到每个 epoch 的末尾:
```
# 计算并保存当前 epoch 的平均精度和损失值
acc_avg = 0
loss_avg = 0
for batch_id, (img, label, bbox, landmark) in enumerate(train_loader):
img = img.to(device)
label = label.to(device).long()
bbox = bbox.to(device)
landmark = landmark.to(device)
class_out, bbox_out, landmark_out = model(img)
acc = accuracy(class_out, label)
cls_loss = class_loss(class_out, label)
box_loss = bbox_loss(bbox_out, bbox, label)
landmarks_loss = landmark_loss(landmark_out, landmark, label)
total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * box_loss + radio_landmark_loss * landmarks_loss
acc_avg += acc
loss_avg += total_loss.item()
acc_avg /= len(train_loader)
loss_avg /= len(train_loader)
epoch_acc.append(acc_avg)
epoch_loss.append(loss_avg)
```
这段代码会计算当前 epoch 的平均精度和损失值,然后将它们添加到 `epoch_acc` 和 `epoch_loss` 数组中。注意,在添加之前需要先定义 `epoch_acc` 和 `epoch_loss` 数组,例如:
```
epoch_acc = []
epoch_loss = []
```
希望可以帮到你!
阅读全文