调用保存的.pb文件画混淆矩阵
时间: 2023-12-14 16:02:40 浏览: 116
以下是使用TensorFlow和sklearn库调用保存的.pb文件画混淆矩阵的示例代码:
```python
import tensorflow as tf
from sklearn.metrics import confusion_matrix
import numpy as np
# 加载保存的.pb文件
with tf.gfile.GFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 加载图
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
# 获取输入和输出张量
input_tensor = graph.get_tensor_by_name('input_tensor:0')
output_tensor = graph.get_tensor_by_name('output_tensor:0')
# 加载测试数据
test_data = np.load('test_data.npy')
test_labels = np.load('test_labels.npy')
# 运行模型
with tf.Session(graph=graph) as sess:
predictions = sess.run(output_tensor, feed_dict={input_tensor: test_data})
# 计算混淆矩阵
cm = confusion_matrix(test_labels, np.argmax(predictions, axis=1))
# 打印混淆矩阵
print('Confusion Matrix:')
print(cm)
```
其中,`model.pb`是保存的.pb文件的文件名,`input_tensor`和`output_tensor`是模型的输入和输出张量的名称,`test_data`和`test_labels`是测试数据和标签。使用`np.argmax(predictions, axis=1)`将模型的输出转换为类别预测,然后计算混淆矩阵并打印。
阅读全文