from pytorch-lightning.callbacks import ModelCheckpoint
时间: 2024-10-22 20:09:47 浏览: 34
Python库 | pytorch-lightning-0.8.1.tar.gz
5星 · 资源好评率100%
`from pytorch_lightning.callbacks import ModelCheckpoint` 是 PyTorch Lightning 库中用于模型检查点保存的回调函数。在深度学习训练过程中,模型checkpoint是一个重要的组件,它允许你在训练期间保存模型的状态,以便在遇到中断(如断电、资源限制等)时能够恢复训练,或者评估训练效果后选择最好的模型。
当你导入 `ModelCheckpoint` 之后,可以在 PyTorch Lightning 训练循环的 `Trainer` 中配置它。你可以设置诸如自动保存最佳验证分数、定期保存、最大保存的数量等参数。它的基本用法通常包括创建一个 `ModelCheckpoint` 对象,并将其添加到 `trainer.callback_metrics` 或者直接作为参数传递给 `Trainer` 的 `callbacks` 参数。
示例:
```python
from pytorch_lightning.callbacks import ModelCheckpoint
# 创建 ModelCheckpoint 实例
checkpoint_callback = ModelCheckpoint(
dirpath='checkpoints', # 检查点文件夹路径
filename='{epoch}-{val_loss:.2f}', # 根据验证损失命名文件
monitor='val_loss', # 监视指标
save_top_k=1, # 只保留最好的 k 个模型
mode='min', # 最小化验证损失
)
# 在 Trainer 中使用
trainer = pl.Trainer(callbacks=[checkpoint_callback])
```
阅读全文