举一下tf.keras.callbacks.ModelCheckpoint的例子
时间: 2024-11-18 18:19:16 浏览: 36
浅谈keras.callbacks设置模型保存策略
`ModelCheckpoint`回调是一个常用的TensorFlow库中的功能,它允许你在训练过程中定期保存模型的特定检查点。这里有一个简单的例子:
```python
from tensorflow.keras.callbacks import ModelCheckpoint
# 创建一个模型实例
model = tf.keras.models.Sequential([...])
# 设置优化器、损失函数和指标
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
# 定义保存路径及检查点规则(例如,每五次epoch保存一次,并且保留最好的验证指标版本)
checkpoint_path = "path/to/save/checkpoints/weights-{epoch:02d}-{val_loss:.2f}.hdf5"
checkpointer = ModelCheckpoint(filepath=checkpoint_path,
verbose=1,
save_best_only=True,
monitor='val_loss',
mode='min')
# 开始训练
history = model.fit(x_train, y_train,
epochs=100,
validation_data=(x_val, y_val),
callbacks=[checkpointer])
```
在这个例子中,每当验证集上的`val_loss`指标有所下降时,`ModelCheckpoint`就会存储一个新的模型权重版本。如果你设置`save_best_only=True`,它只会保留性能最好的模型。这样可以防止长期训练导致模型越来越大,同时也方便后续恢复最佳状态。
阅读全文