基于CIFAR模型将500张32*32像素的图片进行识别并将图片和标签组成键值对保存的代码
时间: 2024-02-15 17:02:58 浏览: 64
通过 代码 识别 图片
以下是一个基于CIFAR-10数据集对500张32x32像素的图片进行分类并保存键值对的示例代码:
```python
import tensorflow as tf
import numpy as np
import os
# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# 定义类别标签
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
# 定义模型
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
# 编译模型
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=10)
# 预测图片分类并保存键值对
images = x_test[:500]
labels = y_test[:500]
predictions = model.predict(images)
predictions = np.argmax(predictions, axis=1)
if not os.path.exists('image_labels'):
os.makedirs('image_labels')
for i in range(len(images)):
image_name = f"image_{i}.png"
image_label = class_names[labels[i][0]]
prediction_label = class_names[predictions[i]]
txt_file = open(f"image_labels/{image_name}.txt", "w")
txt_file.write(f"Image: {image_name}\n")
txt_file.write(f"True Label: {image_label}\n")
txt_file.write(f"Predicted Label: {prediction_label}\n")
txt_file.close()
```
代码中定义了一个包含两个卷积层和两个池化层的简单CNN模型,对CIFAR-10数据集进行训练,并使用模型对前500张测试图片进行分类预测,最终将图片名、真实标签和预测标签保存在文件夹`image_labels`中的txt文件中。
阅读全文