基于已训练好的CIFAR模型将500张32*32像素的图片进行识别并将图片和标签组成键值对保存的代码
时间: 2024-02-15 08:02:59 浏览: 212
以下是一个基于已训练好的CIFAR-10模型对500张32x32像素的图片进行分类并保存键值对的示例代码:
```python
import tensorflow as tf
import numpy as np
import os
# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# 定义类别标签
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
# 加载已训练好的模型
model = tf.keras.models.load_model('cifar_model.h5')
# 预测图片分类并保存键值对
images = x_test[:500]
labels = y_test[:500]
predictions = model.predict(images)
predictions = np.argmax(predictions, axis=1)
if not os.path.exists('image_labels'):
os.makedirs('image_labels')
for i in range(len(images)):
image_name = f"image_{i}.png"
image_label = class_names[labels[i][0]]
prediction_label = class_names[predictions[i]]
txt_file = open(f"image_labels/{image_name}.txt", "w")
txt_file.write(f"Image: {image_name}\n")
txt_file.write(f"True Label: {image_label}\n")
txt_file.write(f"Predicted Label: {prediction_label}\n")
txt_file.close()
```
代码中加载了已训练好的CIFAR-10模型,并使用模型对前500张测试图片进行分类预测,最终将图片名、真实标签和预测标签保存在文件夹`image_labels`中的txt文件中。注意,需要先将已训练好的模型保存为`cifar_model.h5`文件。
阅读全文