调用保存的.pb文件画混淆矩阵
时间: 2023-12-14 12:02:27 浏览: 54
首先,需要加载 TensorFlow 和其他必要的库:
```python
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
```
接着,加载保存的.pb文件和测试数据:
```python
# 加载保存的.pb文件
with tf.gfile.GFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# 加载测试数据
test_images = np.load('test_images.npy')
test_labels = np.load('test_labels.npy')
```
然后,定义一个函数来计算混淆矩阵并绘制它:
```python
def plot_confusion_matrix(test_images, test_labels):
with tf.Session() as sess:
# 获取输入和输出张量
input_tensor = sess.graph.get_tensor_by_name('input:0')
output_tensor = sess.graph.get_tensor_by_name('output:0')
# 对测试集进行预测
predicted_labels = sess.run(output_tensor, feed_dict={input_tensor: test_images})
# 计算混淆矩阵
cm = confusion_matrix(test_labels, predicted_labels.argmax(axis=1))
# 绘制混淆矩阵
plt.figure(figsize=(8, 8))
plt.imshow(cm, cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
plt.xticks(range(10))
plt.yticks(range(10))
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
```
最后,调用函数绘制混淆矩阵:
```python
plot_confusion_matrix(test_images, test_labels)
```
这将绘制一个混淆矩阵,其中每个单元格表示预测为某个类别的图像的数量。对角线上的单元格表示正确分类的图像,非对角线上的单元格表示错误分类的图像。
阅读全文