cifar-10图像分类python完整代码
时间: 2024-12-31 17:48:39 浏览: 17
CIFAR-10是一个常用的小型计算机视觉数据集,包含了10个类别共6万张32x32像素的彩色图片。Python中实现CIFAR-10图像分类的一个完整例子通常会使用深度学习库如TensorFlow或PyTorch。这里我以Keras库为例,给出一个简化的ResNet50模型训练示例:
```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import cifar10
# 加载数据并预处理
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 # 归一化到[0, 1]
# 使用预训练的ResNet50模型提取特征
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
x_train_features = base_model.predict(x_train)
x_test_features = base_model.predict(x_test)
# 添加全局平均池化层和全连接层
x_train_features = GlobalAveragePooling2D()(x_train_features)
x_test_features = GlobalAveragePooling2D()(x_test_features)
predictions = Dense(10, activation='softmax')(x_train_features)
# 创建新的模型
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结预训练模型的权重
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练模型
history = model.fit(x_train_features, y_train, epochs=10, validation_split=0.1)
# 评估模型
test_loss, test_acc = model.evaluate(x_test_features, y_test)
print(f"Test accuracy: {test_acc}")
阅读全文