tensorflow mobilenetv2 二分类 训练自己的数据集 保存训练结果 预测 代码
时间: 2023-07-11 13:25:13 浏览: 206
下面是一个简单的 TensorFlow 2.x 的 MobileNetV2 模型的训练、保存和预测代码,以及如何在自己的数据集上进行二分类训练:
```python
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 设置训练和验证数据集的路径
train_data_dir = '/path/to/train'
val_data_dir = '/path/to/validation'
# 设置图像大小和批次大小
img_width, img_height = 224, 224
batch_size = 32
# 数据增强设置,可以根据需要进行更改
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
val_datagen = ImageDataGenerator(rescale=1./255)
# 加载训练和验证数据集
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='binary')
val_generator = val_datagen.flow_from_directory(
val_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='binary')
# 加载 MobileNetV2 模型,并移除最后一层全连接层
base_model = MobileNetV2(include_top=False, weights='imagenet', input_shape=(img_width, img_height, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
# 添加新的全连接层
predictions = Dense(1, activation='sigmoid')(x)
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结 MobileNetV2 的卷积层,只训练新的全连接层
for layer in base_model.layers:
layer.trainable = False
# 编译模型,设置损失函数和优化器
model.compile(optimizer=Adam(lr=0.001), loss='binary_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(train_generator, epochs=10, validation_data=val_generator)
# 保存模型
model.save('/path/to/model.h5')
# 加载模型进行预测
loaded_model = tf.keras.models.load_model('/path/to/model.h5')
predictions = loaded_model.predict(image)
```
在上述代码中,我们首先加载 MobileNetV2 模型,并移除最后一层全连接层。然后,我们添加一个新的全连接层,并编译模型。接下来,我们训练模型,并保存训练结果。最后,我们加载模型并进行预测。
在训练模型时,我们使用了数据增强来提高模型的泛化能力。同时,我们冻结了 MobileNetV2 的卷积层,只训练新的全连接层。这可以加快模型的训练速度,并提高模型的准确性。
注意,这里的例子是针对二分类任务的。如果您的任务是多分类,您需要相应地更改模型的输出层和损失函数。
阅读全文