如何保存和加载TensorFlow模型?
时间: 2024-11-22 10:23:11 浏览: 13
在TensorFlow中,你可以使用`tf.train.Saver`或者更现代的`tf.saved_model` API来保存和加载模型。以下是基本步骤:
**保存模型**:
1. 定义一个Saver实例:
```python
saver = tf.train.Saver()
```
2. 训练模型并达到你想要保存的状态:
```python
with tf.Session() as sess:
# ...训练过程...
saver.save(sess, 'path/to/save/model')
```
或者使用`tf.saved_model.save()`,如果是在TF 2.x版本及以上:
```python
model_path = 'path/to/save/model'
builder = tf.saved_model.builder.SavedModelBuilder(model_path)
# ...构建模型元数据...
builder.save()
```
**加载模型**:
1. 使用`tf.train.import_meta_graph`导入元图文件(`.meta`)和数据文件(`.data-*`):
```python
with tf.Session() as sess:
meta_file = 'path/to/model.meta'
data_files = ['path/to/model.data*']
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, tf.train.latest_checkpoint('path/to/models'))
```
2. 或者直接使用`tf.keras.models.load_model`加载Keras模型:
```python
loaded_model = tf.keras.models.load_model('path/to/model.h5')
```
阅读全文