用resnet50跑分类图像任务代码
时间: 2024-07-31 20:00:53 浏览: 105
baseline_model_resnet50_神经网络_图像分类_源码
ResNet50是一种深度残差网络,常用于图像分类等计算机视觉任务。以下是一个简单的例子,展示如何使用Keras库在Python中加载预训练的ResNet50模型,并应用于图像分类:
```python
# 导入所需的库
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
# 加载预训练的ResNet50模型(不包括顶部全连接层)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# 添加全局平均池化层和全连接层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x) # 类别数可以根据实际情况调整
predictions = Dense(num_classes, activation='softmax')(x)
# 创建新的模型,包含原始ResNet50的底部部分和我们添加的新部分
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结基础模型的权重
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
# 数据增强
datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
train_generator = datagen.flow_from_directory(train_dir, target_size=(224, 224), batch_size=batch_size, class_mode='categorical')
# 训练模型
model.fit(train_generator, epochs=num_epochs, validation_data=val_generator)
# 最后解冻全部层进行微调
for layer in model.layers:
layer.trainable = True
# 继续训练(如果需要)
model.fit(train_generator, epochs=num_epochs_tune, validation_data=val_generator)
```
在这个示例中,你需要替换`train_dir`、`val_dir`、`num_classes`、`batch_size`、`num_epochs`和`num_epochs_tune`为你实际的项目路径、类别数、批量大小、初始训练轮数以及微调轮数。
阅读全文