手写数字识别,如何调用saver.save保存下来的深度模型,并测试训练效果。请列举代码
时间: 2024-05-21 21:11:38 浏览: 111
Sure,以下是利用 TensorFlow 实现手写数字识别的代码示例,包括保存模型和测试训练效果的代码段:
1. 保存模型
使用 TensorFlow 的 Saver 类,可以保存训练好的模型。具体步骤如下:
# 定义模型变量与计算图
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# 定义训练操作
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 创建 Saver 对象
saver = tf.train.Saver()
# 训练模型及时保存
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
if i % 100 == 0:
# 保存模型
saver.save(sess, 'mnist_model/mnist_model.ckpt', global_step=i)
2. 加载模型并测试
可以使用 tf.train.import_meta_graph() 方法加载包含计算图结构的.meta 文件,然后使用 Saver.restore() 方法加载包含变量值的.ckpt 文件,最后运行模型验证准确率。
# 加载模型,构建计算图
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_test = x_test.reshape(-1, 784)
x_test = x_test.astype('float32') / 255
y_test = tf.keras.utils.to_categorical(y_test, 10)
with tf.Session() as sess:
# 加载计算图
saver = tf.train.import_meta_graph('mnist_model/mnist_model.ckpt-900.meta')
saver.restore(sess, tf.train.latest_checkpoint('mnist_model/'))
# 获取模型输入与输出
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('x:0')
y = graph.get_tensor_by_name('Softmax:0')
accuracy = graph.get_tensor_by_name('Mean:0')
y_ = graph.get_tensor_by_name('y_:0')
# 运行模型验证准确率
acc = sess.run(accuracy, feed_dict={x: x_test, y_: y_test})
print(acc)
希望这些代码对你有所帮助。
阅读全文