不要额外调用train_one_epoch
时间: 2024-10-18 21:13:47 浏览: 19
Keras之fit_generator与train_on_batch用法
如果你不想在训练结束后额外调用 `train_one_epoch` 函数来绘制和保存损失曲线,你可以选择在每次迭代结束后直接处理绘图和保存步骤。这通常在你需要实时监控损失变化的情况下很有帮助。以下是修改后的部分代码段:
```python
def train_one_epoch(
model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm, plot_interval,
):
model.train()
device = next(model.parameters()).device
loss_history = {'main_loss': [], 'mse_loss': [], 'bpp_loss': [], 'aux_loss': []}
for i, d in enumerate(train_dataloader):
d = d.to(device)
optimizer.zero_grad()
aux_optimizer.zero_grad()
out_net = model(d)
out_criterion = criterion(out_net, d)
out_criterion["loss"].backward()
if clip_max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
optimizer.step()
aux_loss = model.aux_loss()
aux_loss.backward()
aux_optimizer.step()
# 更新损失历史记录并立即绘制
if (i + 1) % plot_interval == 0:
loss_history['main_loss'].append(out_criterion["loss"].item())
loss_history['mse_loss'].append(out_criterion["mse_loss"].item())
loss_history['bpp_loss'].append(out_criterion["bpp_loss"].item())
loss_history['aux_loss'].append(aux_loss.item())
# 当前迭代结束时绘制并显示结果
draw_and_save_curve(epoch, loss_history, losses_filename, i, plot_interval)
# 可能会在这里加入其他的监控操作,如打印当前的损失值或其他指标
def draw_and_save_curve(epoch, loss_history, losses_filename, current_iter, total_iters):
#... (绘制和保存代码与之前相同)
print(
f"Iteration {current_iter+1}/{total_iters}: "
f"[Loss: {loss_history['main_loss'][-1]:.3f}, MSE: {loss_history['mse_loss'][-1]:.3f}, "
f"BPP: {loss_history['bpp_loss'][-1]:.2f}]"
)
# 调用train_one_epoch时不包含epoch参数,因为现在它不再用于每个epoch的起始处
for _ in range(num_epochs):
train_one_epoch(...)
```
这样,你可以在训练过程中看到即时的损失更新,并在每个`plot_interval`步长内保存最新的损失情况。请注意,你需要调整`draw_and_save_curve`函数以适应这种实时反馈模式。
阅读全文