保存训练好的模型及参数 keras代码
时间: 2023-08-11 14:03:17 浏览: 102
C++ 从 HDF5 文件读取 Keras 神经网络模型和参数
在 Keras 中,可以通过 Model 类的 save() 方法来保存训练好的模型及参数,具体代码如下:
```python
from keras.models import load_model
# 创建并编译模型
model = create_model()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=10, batch_size=32)
# 保存模型及参数
model.save('my_model.h5')
# 加载模型及参数
loaded_model = load_model('my_model.h5')
```
在这个例子中,我们首先创建并编译了一个模型,然后使用 fit() 方法来训练模型。训练完成后,我们调用 save() 方法将模型及参数保存到文件 my_model.h5 中。如果需要重新加载模型及参数,可以使用 load_model() 方法。
需要注意的是,Keras 保存的模型文件包含了模型的结构、权重、优化器状态等信息,因此可以直接在不同的机器上或者不同的 Python 环境中加载和使用。
阅读全文