tensorflow mobilenetv3 二分类 训练自己的数据集 防止过拟合 保存训练结果 预测 代码
时间: 2023-07-11 09:27:08 浏览: 92
好的,以下是您需要的代码:
1. 数据集准备和预处理
```python
import tensorflow as tf
import numpy as np
import os
# 定义数据集目录和类别数
data_dir = 'path/to/dataset'
num_classes = 2
# 定义图像大小和批次大小
image_size = 224
batch_size = 32
# 定义数据增强器,包括随机裁剪、水平翻转和随机亮度、对比度和饱和度变换
data_augmentation = tf.keras.Sequential([
tf.keras.layers.experimental.preprocessing.RandomCrop(image_size, image_size),
tf.keras.layers.experimental.preprocessing.RandomFlip(mode='horizontal'),
tf.keras.layers.experimental.preprocessing.RandomContrast(0.1),
tf.keras.layers.experimental.preprocessing.RandomBrightness(0.1),
tf.keras.layers.experimental.preprocessing.RandomSaturation(0.1),
])
# 定义训练集和验证集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset='training',
seed=123,
image_size=(image_size, image_size),
batch_size=batch_size
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset='validation',
seed=123,
image_size=(image_size, image_size),
batch_size=batch_size
)
# 对训练集应用数据增强器
train_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y))
# 对验证集进行缓存和预取以加速读取速度
val_ds = val_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
```
2. 模型构建
```python
from tensorflow.keras.applications import MobileNetV3Small
# 加载 MobileNetV3 模型,不包括分类层
base_model = MobileNetV3Small(include_top=False, weights='imagenet', input_shape=(image_size, image_size, 3))
# 冻结模型的所有层,以便只训练新添加的分类层
base_model.trainable = False
# 添加全局平均池化层和分类层
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(num_classes, activation='softmax')
# 构建完整模型
model = tf.keras.Sequential([
base_model,
global_average_layer,
prediction_layer
])
```
3. 模型编译和训练
```python
# 编译模型,选择损失函数、优化器和评价指标
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
# 训练模型,选择训练轮数和回调函数
epochs = 10
history = model.fit(train_ds,
epochs=epochs,
validation_data=val_ds,
callbacks=[tf.keras.callbacks.ModelCheckpoint('model.h5', save_best_only=True)])
```
4. 模型预测和保存
```python
# 加载保存的最优模型
model = tf.keras.models.load_model('model.h5')
# 对单张图像进行预测
img_path = 'path/to/image'
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(image_size, image_size))
img_array = tf.keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # 扩展维度以匹配模型输入
predictions = model.predict(img_array)
print(predictions)
# 保存整个模型为 SavedModel 格式
tf.saved_model.save(model, 'saved_model')
```
阅读全文