for epoch in range(num_epochs): model.train() train_loss = 0 train_mae = 0 for batch_data, batch_target in train_dataloader: optimizer.zero_grad() output = model(batch
时间: 2023-12-06 20:44:00 浏览: 171
_data) # 前向传播
loss = loss_fn(output, batch_target) # 计算loss
train_loss += loss.item()
mae = mean_absolute_error(output.detach().cpu().numpy(), batch_target.detach().cpu().numpy()) # 计算MAE
train_mae += mae
loss.backward() # 反向传播
optimizer.step() # 更新梯度
train_loss /= len(train_dataloader)
train_mae /= len(train_dataloader)
model.eval() # 切换到评估模式
eval_loss = 0
eval_mae = 0
with torch.no_grad():
for batch_data, batch_target in eval_dataloader:
output = model(batch_data)
loss = loss_fn(output, batch_target)
eval_loss += loss.item()
mae = mean_absolute_error(output.detach().cpu().numpy(), batch_target.detach().cpu().numpy())
eval_mae += mae
eval_loss /= len(eval_dataloader)
eval_mae /= len(eval_dataloader)
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.4f}, Eval Loss: {eval_loss:.4f}, Eval MAE: {eval_mae:.4f}")
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
torch.save(model.state_dict(), save_path)
print("Save model successfully!")
阅读全文
相关推荐
![pth](https://img-home.csdnimg.cn/images/20250102104920.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)