MNIST数据集手写体数字识别的例子,目前显示结果为“Accuracy:0.9352”,说明有些手写体的图片识别错误,如第8张图片,如下图所示,将5识别成了6。修改程序代码,记录所有识别错误的图片信息,并且将所有识别错误的图片单独保存在一个指定的文件夹内。
时间: 2024-05-04 16:17:50 浏览: 8
好的,我们可以在程序中添加代码来记录所有识别错误的图片信息,并将这些图片单独保存在一个指定的文件夹中。以下是修改后的程序代码:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import cv2
# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 数据预处理
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
# 创建模型
model = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(10, activation="softmax"),
]
)
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# 训练模型
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_split=0.1)
# 在测试集上评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print("Accuracy:", test_acc)
# 创建保存错误图片的文件夹
if not os.path.exists("error_images"):
os.makedirs("error_images")
# 预测测试集中的每张图片,并保存错误图片
for i in range(len(x_test)):
img = x_test[i]
true_label = y_test[i]
pred_label = np.argmax(model.predict(np.expand_dims(img, 0))[0])
if pred_label != true_label:
cv2.imwrite(f"error_images/{i}_true_{true_label}_pred_{pred_label}.png", img * 255)
```
在这个修改后的程序中,我们创建了一个名为“error_images”的文件夹来存储所有识别错误的图片。在预测测试集中的每个图片时,我们使用np.argmax来找到最有可能的预测标签,并将预测标签与真实标签进行比较。如果预测标签与真实标签不同,我们就将该图片保存到“error_images”文件夹中,文件名包含图片的索引、真实标签和预测标签。