ModelCheckpoint
时间: 2023-08-16 13:10:04 浏览: 147
`ModelCheckpoint` 是 Keras 中提供的一个回调函数,它用于在每个训练周期后保存模型的权重或整个模型。
以下是一个使用 `ModelCheckpoint` 回调函数来保存模型权重的示例:
```
from tensorflow.keras.callbacks import ModelCheckpoint
# 创建一个 ModelCheckpoint 回调函数
checkpoint = ModelCheckpoint('model_weights.h5', save_weights_only=True, save_best_only=True, monitor='val_loss', mode='min', verbose=1)
# 在模型训练期间将 ModelCheckpoint 回调函数传递给 fit 函数
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val), callbacks=[checkpoint])
```
在上面的代码中,我们创建了一个名为 `checkpoint` 的回调函数,并将其传递给 Keras 的 `fit` 函数中的 `callbacks` 参数。`ModelCheckpoint` 回调函数将在每个训练周期结束后检查验证集的损失值,并将模型权重保存到文件 `model_weights.h5` 中,当且仅当验证集的损失值最小时才会保存。
你也可以使用 `ModelCheckpoint` 回调函数来保存整个模型,将 `save_weights_only` 参数设置为 `False` 即可。例如:
```
from tensorflow.keras.callbacks import ModelCheckpoint
# 创建一个 ModelCheckpoint 回调函数
checkpoint = ModelCheckpoint('model.h5', save_weights_only=False, save_best_only=True, monitor='val_loss', mode='min', verbose=1)
# 在模型训练期间将 ModelCheckpoint 回调函数传递给 fit 函数
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val), callbacks=[checkpoint])
```
在上面的代码中,我们将 `save_weights_only` 参数设置为 `False`,这将在每个训练周期结束后检查验证集的损失值,并将整个模型保存到文件 `model.h5` 中,当且仅当验证集的损失值最小时才会保存。
阅读全文