Tensorflow模型保存与调用
时间: 2023-11-18 19:02:25 浏览: 93
TensorFlow是一个强大的机器学习库,可以用于构建和训练各种深度学习模型。在训练完一个模型后,我们通常会将模型保存到磁盘上以备将来使用。TensorFlow提供了一种简单的方法来保存和加载模型。
在 TensorFlow 中,我们可以使用 `tf.saved_model` API 将模型保存到一个文件夹中。该文件夹包含了模型的图和变量等信息。保存模型时,我们可以指定要保存的内容,例如只保存变量或只保存图等等。
以下是一个保存和加载模型的示例:
```python
import tensorflow as tf
from tensorflow import keras
# 构建一个简单的模型
model = keras.Sequential([
keras.layers.Dense(64, activation='relu', input_shape=(784,)),
keras.layers.Dense(10, activation='softmax')
])
# 训练模型
model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
# 保存模型
export_path = './saved_model'
tf.saved_model.simple_save(
keras.backend.get_session(),
export_path,
inputs={'input_image': model.input},
outputs={t.name: t for t in model.outputs})
# 加载模型
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_path)
graph = tf.get_default_graph()
inputs = graph.get_tensor_by_name('input_image:0')
outputs = graph.get_tensor_by_name('dense_1/Softmax:0')
result = sess.run(outputs, feed_dict={inputs: x_test})
print(result)
```
在上面的示例中,我们首先构建了一个简单的神经网络模型并训练了它。然后我们使用 `tf.saved_model.simple_save` 将模型保存到一个文件夹中。在保存模型时,我们指定了模型的输入和输出。最后,我们使用 `tf.saved_model.loader.load` 方法加载模型并运行它。
值得注意的是,在加载模型时,我们需要使用 `tf.Session` 和 `tf.Graph` 来创建一个新的计算图,并使用 `tf.saved_model.loader.load` 方法从文件夹中加载模型。在加载模型后,我们可以使用 `tf.get_default_graph` 方法获取默认的计算图,并使用 `graph.get_tensor_by_name` 方法获取输入和输出张量的引用。
这就是 TensorFlow 中保存和加载模型的基本方法。当我们需要在生产环境中使用模型时,我们可以使用保存的模型轻松地加载和运行它。
阅读全文