tf.keras fit
时间: 2023-06-21 20:23:21 浏览: 94
`tf.keras` 中的 `fit` 是一个用于模型训练的函数,它可以训练一个 Keras 模型来拟合给定的训练数据集。`fit` 函数的语法为:
```
model.fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
```
其中,常用的参数包括:
- `x`:输入数据。如果模型只有一个输入,那么 `x` 的类型应为 Numpy 数组(或 Pandas DataFrame),如果模型有多个输入,那么 `x` 的类型应为列表,列表中的每个元素分别对应于模型每个输入。
- `y`:标签数据。与 `x` 类似,如果模型有多个输出,那么 `y` 的类型应为列表,列表中的每个元素分别对应于模型每个输出。
- `batch_size`:整数,指定进行梯度下降时每个 batch 包含的样本数。
- `epochs`:整数,训练的轮数。
- `verbose`:日志显示模式。0 表示不显示日志,1 表示显示进度条,2 表示每个 epoch 显示一行日志。
- `callbacks`:Keras 回调函数列表。回调函数是在训练过程中被调用的函数,例如用于保存检查点、提前停止训练、记录训练历史等。
- `validation_split`:0 到 1 之间的浮点数,用于指定训练集中作为验证集使用的比例。
- `validation_data`:用于验证的数据。与 `x` 和 `y` 类似,如果模型有多个输入和输出,那么 `validation_data` 的类型应为元组或列表。
- `shuffle`:布尔值,指定是否在每个 epoch 开始时随机打乱训练数据。
- `steps_per_epoch`:在声明一个 epoch 完成并开始下一个 epoch 之前从 generator 产生的总步数(批次样本)。
- `validation_steps`:只有在 steps_per_epoch 未指定时才有用,验证数据集上的总步数(批次样本)。
`fit` 函数在训练过程中会返回一个 `History` 对象,包含训练过程中损失和指标的变化历史。可以使用 `history.history` 属性来访问这个对象。
阅读全文