resnet50 二分类 训练自己的数据集 保存模型 使用保存好的模型进行测试 代码
时间: 2023-09-15 19:21:50 浏览: 173
以下是使用ResNet50模型进行二分类任务的代码示例,包括训练和测试过程:
```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import ResNet50
# 训练集和验证集路径
train_data_dir = 'path/to/train/dataset'
validation_data_dir = 'path/to/validation/dataset'
# 训练集和验证集的图像尺寸
img_width, img_height = 224, 224
# 批次大小
batch_size = 32
# 训练集和验证集的样本数量
nb_train_samples = 1000
nb_validation_samples = 200
# 创建ResNet50模型,去掉头部分类层
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))
# 添加自定义的分类层
x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)
# 构建完整模型
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结ResNet50的卷积层,只训练自定义分类层
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 数据增强
train_datagen = ImageDataGenerator(rescale=1. / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
batch_size=batch_size, class_mode='binary')
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
batch_size=batch_size, class_mode='binary')
# 训练模型
model.fit(train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=10, validation_data=validation_generator,
validation_steps=nb_validation_samples // batch_size)
# 保存模型
model.save('my_model.h5')
# 加载模型
model = tf.keras.models.load_model('my_model.h5')
# 测试模型
test_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
batch_size=batch_size, class_mode='binary')
test_loss, test_acc = model.evaluate(test_generator)
print('Test accuracy:', test_acc)
```
在上面的代码中,我们使用了Keras的ImageDataGenerator类来进行数据增强,并使用fit方法训练模型。模型训练完毕后,我们使用save方法来保存模型,使用load_model方法来加载模型。最后,我们使用evaluate方法来测试模型的性能。
阅读全文