如何调用saver.save保存下来的深度模型?
时间: 2024-04-25 07:03:56 浏览: 94
可以使用 TensorFlow 中的 `tf.train.import_meta_graph` 方法导入模型的 meta graph,然后使用 `tf.train.Saver` 类的 `restore` 方法将保存的变量值恢复到新的 `tf.Session` 中。具体实现方式如下:
```python
import tensorflow as tf
# 定义需要保存的变量
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')
W = tf.Variable(tf.zeros([784, 10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
logits = tf.add(tf.matmul(x, W), b, name='logits')
# 定义 Saver
saver = tf.train.Saver()
# 创建 Session 并训练模型
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# ... 省略训练过程 ...
# 保存模型
saver.save(sess, './my_model.ckpt')
# 导入模型 meta graph
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./my_model.ckpt.meta')
saver.restore(sess, './my_model.ckpt')
# 获取模型中的变量
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('x:0')
y = graph.get_tensor_by_name('y:0')
logits = graph.get_tensor_by_name('logits:0')
# 在导入的模型上进行预测
# ... 省略预测过程 ...
```
阅读全文