调用保存的.pb文件画混淆矩阵
时间: 2023-12-14 19:02:25 浏览: 74
在jupyter notebook中调用.ipynb文件方式
如果您有一个训练好的模型,可以使用以下代码调用保存的.pb文件来画混淆矩阵:
```python
import tensorflow as tf
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
# 加载保存的.pb文件
with tf.gfile.GFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# 定义输入和输出节点的名称
input_node = 'input:0'
output_node = 'output:0'
# 获取输入和输出张量
x = graph.get_tensor_by_name(input_node)
y = graph.get_tensor_by_name(output_node)
# 加载测试数据集
test_data = np.load('test_data.npy')
test_labels = np.load('test_labels.npy')
# 运行测试数据集
with tf.Session(graph=graph) as sess:
y_pred = sess.run(y, feed_dict={x: test_data})
# 计算混淆矩阵
cm = confusion_matrix(test_labels, np.argmax(y_pred, axis=1))
# 可视化混淆矩阵
plt.imshow(cm, cmap=plt.cm.Blues)
plt.colorbar()
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.xticks(np.arange(10))
plt.yticks(np.arange(10))
plt.show()
```
上述代码中,我们首先加载保存的.pb文件,然后获取输入和输出节点的名称以及对应的张量。接下来,我们加载测试数据集并运行模型预测,最后计算混淆矩阵并可视化。
阅读全文